Worker Exception & Timeout Refactor

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Seth Howes <sethshowes@gmail.com>
This commit is contained in:
Matt Beton
2025-08-02 16:28:37 +01:00
committed by GitHub
parent 92c9688bf0
commit 1fe4ed3442
56 changed files with 2519 additions and 893 deletions

View File

@@ -3,13 +3,14 @@ name: Build and Release Exo macOS App
on:
push:
tags:
- 'v*' # Trigger only on version tags
- 'v*' # Trigger on version tags
branches:
- main # Also build on main branch for testing
- app-staging # Add app-staging for testing
pull_request:
branches:
- main # Test builds on PRs
- staging # Test builds on PRs to staging
- main # Build on PRs to main
jobs:
build-exov2-macos:
@@ -20,18 +21,6 @@ jobs:
with:
fetch-depth: 0
- name: Setup Rust (nightly)
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: nightly
components: rustfmt, clippy
default: true
- name: Set Rust toolchain override
run: |
rustup default nightly
cd rust && rustup override set nightly
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -52,12 +41,6 @@ jobs:
uv python install
uv sync --locked --all-extras
- name: Build Rust Components
env:
RUSTFLAGS: "-A unused-imports -A dead-code -A unreachable-code"
run: |
just build-all
- name: Install Python Bindings
run: |
uv pip install dist/exo_pyo3_bindings-*.whl

43
configure_mlx.sh Normal file
View File

@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# Derive the iogpu wired-memory sysctls from the machine's total RAM
# (hw.memsize) and apply them:
#   iogpu.wired_limit_mb = max(80% of RAM, RAM - 5 GB)
#   iogpu.wired_lwm_mb   = max(70% of RAM, RAM - 8 GB)

# Read the total memory in bytes. Stop early when it is unavailable or not a
# plain number, instead of carrying on and writing empty values to sysctl.
TOTAL_MEM_BYTES=$(sysctl -n hw.memsize 2>/dev/null)
case "$TOTAL_MEM_BYTES" in
    '' | *[!0-9]*)
        echo "Unable to read hw.memsize via sysctl; wired limits left unchanged" >&2
        exit 1
        ;;
esac

# Get the total memory in MB
TOTAL_MEM_MB=$((TOTAL_MEM_BYTES / 1024 / 1024))

# max A B: print the larger of two integers.
max() {
    if [ "$1" -gt "$2" ]; then
        echo "$1"
    else
        echo "$2"
    fi
}

# apply_sysctl NAME VALUE: write one sysctl. As root, write directly;
# otherwise try without sudo first and fall back to sudo if that fails.
apply_sysctl() {
    local setting="$1=$2"
    if [ "$EUID" -eq 0 ]; then
        sysctl -w "$setting"
    else
        sysctl -w "$setting" 2>/dev/null || sudo sysctl -w "$setting"
    fi
}

# Calculate 80% and total-5GB in MB
EIGHTY_PERCENT=$((TOTAL_MEM_MB * 80 / 100))
MINUS_5GB=$((TOTAL_MEM_MB - 5120))

# Calculate 70% and total-8GB in MB
SEVENTY_PERCENT=$((TOTAL_MEM_MB * 70 / 100))
MINUS_8GB=$((TOTAL_MEM_MB - 8192))

# Each setting takes the higher of its two candidates
WIRED_LIMIT_MB=$(max "$EIGHTY_PERCENT" "$MINUS_5GB")
WIRED_LWM_MB=$(max "$SEVENTY_PERCENT" "$MINUS_8GB")

# Display the calculated values
echo "Total memory: $TOTAL_MEM_MB MB"
echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"

# Apply the values with sysctl
apply_sysctl iogpu.wired_limit_mb "$WIRED_LIMIT_MB"
apply_sysctl iogpu.wired_lwm_mb "$WIRED_LWM_MB"

View File

@@ -483,25 +483,89 @@
}
.model-select {
background-color: var(--exo-medium-gray);
background: linear-gradient(135deg, #2a2a2a 0%, #3c3c3c 50%, #2a2a2a 100%);
color: var(--exo-light-gray);
border: 1px solid var(--exo-light-gray);
border-radius: 6px;
padding: 10px 12px;
font-size: 14px;
border: 2px solid rgba(255, 215, 0, 0.2);
border-radius: 12px;
padding: 14px 20px 14px 16px;
font-size: 15px;
font-family: var(--font-family);
font-weight: 500;
cursor: pointer;
transition: all 0.25s cubic-bezier(0.4, 0, 0.2, 1);
box-shadow:
0 4px 12px rgba(0, 0, 0, 0.25),
inset 0 1px 0 rgba(255, 255, 255, 0.12),
inset 0 -1px 0 rgba(0, 0, 0, 0.1);
position: relative;
appearance: none;
width: 100%;
min-height: 48px;
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23FFD700' d='M6 8.5L2.5 5h7z'/%3E%3C/svg%3E");
background-position: calc(100% - 16px) center;
background-size: 12px 12px;
background-repeat: no-repeat;
}
.model-select:hover {
background: linear-gradient(135deg, #363636 0%, #484848 50%, #363636 100%);
border-color: rgba(255, 215, 0, 0.5);
box-shadow:
0 6px 20px rgba(0, 0, 0, 0.3),
inset 0 1px 0 rgba(255, 255, 255, 0.15),
inset 0 -1px 0 rgba(0, 0, 0, 0.1),
0 0 0 1px rgba(255, 215, 0, 0.1);
transform: translateY(-2px);
}
.model-select:focus {
outline: none;
border-color: var(--exo-yellow);
box-shadow: 0 0 0 2px rgba(255, 215, 0, 0.2);
box-shadow:
0 0 0 4px rgba(255, 215, 0, 0.25),
0 8px 24px rgba(0, 0, 0, 0.4),
inset 0 1px 0 rgba(255, 255, 255, 0.2),
inset 0 -1px 0 rgba(0, 0, 0, 0.1);
background: linear-gradient(135deg, #404040 0%, #525252 50%, #404040 100%);
transform: translateY(-1px);
}
.model-select:active {
transform: translateY(0);
box-shadow:
0 2px 8px rgba(0, 0, 0, 0.3),
inset 0 1px 0 rgba(255, 255, 255, 0.1),
inset 0 2px 6px rgba(0, 0, 0, 0.2);
}
.model-select:disabled {
background: linear-gradient(135deg, #1a1a1a 0%, #222222 100%);
color: #555555;
border-color: #333333;
cursor: not-allowed;
transform: none;
box-shadow: inset 0 2px 6px rgba(0, 0, 0, 0.4);
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23555555' d='M6 8.5L2.5 5h7z'/%3E%3C/svg%3E");
}
.model-select option {
background-color: var(--exo-medium-gray);
background-color: var(--exo-dark-gray);
color: var(--exo-light-gray);
padding: 12px 16px;
border: none;
font-size: 14px;
font-weight: 500;
}
.model-select option:hover {
background-color: var(--exo-medium-gray);
color: var(--exo-yellow);
}
.model-select option:checked {
background-color: var(--exo-yellow);
color: var(--exo-black);
font-weight: 600;
}
.launch-button {
@@ -576,6 +640,7 @@
color: var(--exo-light-gray);
font-style: italic;
margin-top: 40px;
margin-bottom: 30px;
}
@@ -588,6 +653,13 @@
<!-- Sidebar -->
<div class="sidebar" id="instancesSidebar">
<div class="sidebar-header">
<h3>Running Instances</h3>
</div>
<div class="sidebar-content" id="instancesList">
<div class="no-instances">Loading instances...</div>
</div>
<div class="sidebar-header">
<h3>Launch Instance</h3>
</div>
@@ -601,13 +673,6 @@
<div id="launchStatus" class="launch-status"></div>
</div>
</div>
<div class="sidebar-header">
<h3>Running Instances</h3>
</div>
<div class="sidebar-content" id="instancesList">
<div class="no-instances">Loading instances...</div>
</div>
</div>
<div class="dashboard-header">

View File

@@ -29,7 +29,6 @@ def mx_barrier():
)
)
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -130,3 +129,18 @@ async def apply_chat_template(
)
return prompt
def mlx_force_oom(size: int = 40000) -> None:
    """
    Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.

    Two random ``size`` x ``size`` float32 matrices are created on the GPU
    device and evaluated, then combined through chained matmuls and a sigmoid
    whose result is evaluated as well.
    """
    mx.set_default_device(mx.gpu)  # type: ignore
    shape = (size, size)
    lhs = mx.random.uniform(shape=shape, dtype=mx.float32)  # type: ignore
    rhs = mx.random.uniform(shape=shape, dtype=mx.float32)  # type: ignore
    mx.eval(lhs, rhs)  # type: ignore
    product = mx.matmul(lhs, rhs)  # type: ignore
    left_term = mx.matmul(lhs, product)  # type: ignore
    right_term = mx.matmul(rhs, product)  # type: ignore
    activated = mx.sigmoid(left_term + right_term)  # type: ignore
    mx.eval(activated)  # type: ignore

View File

@@ -32,6 +32,7 @@ from shared.types.events.commands import (
CommandType,
CreateInstanceCommand,
DeleteInstanceCommand,
TaskFinishedCommand,
)
from shared.types.events.components import EventFromEventLog
from shared.types.models import ModelMetadata
@@ -177,6 +178,11 @@ class API:
if event.chunk.finish_reason is not None:
yield "data: [DONE]"
finished = True
command = TaskFinishedCommand(
command_id=command_id
)
self.command_buffer.append(command)
return

View File

@@ -14,11 +14,14 @@ from shared.apply import apply
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.types.common import NodeId
from shared.types.common import CommandId, NodeId
from shared.types.events import (
Event,
Heartbeat,
InstanceDeleted,
TaskCreated,
TaskDeleted,
TopologyEdgeDeleted,
TopologyNodeCreated,
)
from shared.types.events.commands import (
@@ -26,6 +29,7 @@ from shared.types.events.commands import (
Command,
CreateInstanceCommand,
DeleteInstanceCommand,
TaskFinishedCommand,
)
from shared.types.state import State
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
@@ -43,6 +47,7 @@ class Master:
self.command_buffer = command_buffer
self.global_events = global_events
self.worker_events = worker_events
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.forwarder_supervisor = ForwarderSupervisor(
self.node_id,
forwarder_binary_path=forwarder_binary_path,
@@ -96,6 +101,8 @@ class Master:
task_params=next_command.request_params
)
))
self.command_task_mapping[next_command.command_id] = task_id
case DeleteInstanceCommand():
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
transition_events = get_transition_events(self.state.instances, placement)
@@ -104,6 +111,11 @@ class Master:
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
transition_events = get_transition_events(self.state.instances, placement)
next_events.extend(transition_events)
case TaskFinishedCommand():
next_events.append(TaskDeleted(
task_id=self.command_task_mapping[next_command.command_id]
))
del self.command_task_mapping[next_command.command_id]
await self.event_log_for_writes.append_events(next_events, origin=self.node_id)
# 2. get latest events
@@ -119,6 +131,24 @@ class Master:
self.state = apply(self.state, event_from_log)
self.logger.info(f"state: {self.state.model_dump_json()}")
# TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
write_events: list[Event] = []
if any([isinstance(event_from_log.event, TopologyEdgeDeleted) for event_from_log in events]):
connected_node_ids = set([x.node_id for x in self.state.topology.list_nodes()])
for instance_id, instance in self.state.instances.items():
delete = False
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
delete = True
break
if delete:
write_events.append(InstanceDeleted(
instance_id=instance_id
))
if write_events:
await self.event_log_for_writes.append_events(events=write_events, origin=self.node_id)
async def run(self):
self.state = await self._get_state_snapshot()

View File

@@ -41,7 +41,14 @@ def get_instance_placements(
raise ValueError("No cycles found with sufficient memory")
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None))
selected_cycle = None
for cycle in smallest_cycles:
cycle_graph: Topology = topology.get_subgraph_from_nodes(cycle)
if cycle_graph.is_thunderbolt_cycle(cycle):
selected_cycle = cycle
break
if selected_cycle is None:
selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None))
shard_assignments = get_shard_assignments(command.model_meta, selected_cycle)

View File

@@ -83,6 +83,10 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
if not cycles:
return []
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
get_thunderbolt = True
cycle = cycles[0]
hosts: list[Host] = []
for i in range(len(cycle)):
@@ -92,8 +96,10 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
for connection in cycle_digraph.list_connections():
if (connection.local_node_id == current_node.node_id and
connection.send_back_node_id == next_node.node_id):
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.send_back_multiaddr.ipv4_address,
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port
)
hosts.append(host)

View File

@@ -7,6 +7,7 @@ import (
"log"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"github.com/libp2p/go-libp2p/core/network"
@@ -18,6 +19,10 @@ var (
eventsDBPath string
eventsDB *sql.DB
eventsDBMu sync.Mutex
// Track connections to prevent duplicate events
connectionTracker = make(map[string]bool)
connTrackerMu sync.Mutex
)
// SetEventsDBPath sets the path to the events database
@@ -166,33 +171,44 @@ func (n *NotifeeHandler) Connected(net network.Network, conn network.Conn) {
localAddr := conn.LocalMultiaddr()
remoteAddr := conn.RemoteMultiaddr()
// Get the actual node IDs (not peer IDs)
// Check if we've already processed this connection
connKey := fmt.Sprintf("%s-%s", conn.LocalPeer(), remotePeer)
connTrackerMu.Lock()
if connectionTracker[connKey] {
connTrackerMu.Unlock()
log.Printf("Skipping duplicate connection event for %s", remotePeer)
return
}
connectionTracker[connKey] = true
connTrackerMu.Unlock()
// Get the local node ID
localNodeID := GetNodeId()
// For remote node, we need to extract from peer ID or use a mapping
// For now, we'll use the peer ID as a placeholder
// TODO: Implement proper node ID mapping/discovery
remoteNodeID := remotePeer.String()
// Create connection event
event := &TopologyEdgeCreated{
EventType: EventTypeTopologyEdgeCreated,
EventID: uuid.New().String(),
Edge: Connection{
LocalNodeID: localNodeID,
SendBackNodeID: remoteNodeID,
LocalMultiaddr: parseMultiaddr(localAddr),
SendBackMultiaddr: parseMultiaddr(remoteAddr),
ConnectionProfile: nil, // TODO: Add connection profiling if needed
},
}
// Write event to database
if err := writeEvent(EventTypeTopologyEdgeCreated, event); err != nil {
log.Printf("Failed to write edge created event: %v", err)
} else {
log.Printf("Wrote edge created event: %s -> %s", localNodeID, remoteNodeID)
}
// Asynchronously exchange node IDs and write event
go func() {
mapper := GetNodeIDMapper()
// Add a small delay to ensure both sides are ready
time.Sleep(100 * time.Millisecond)
// Exchange node IDs
if err := mapper.ExchangeNodeID(remotePeer); err != nil {
log.Printf("Failed to exchange node ID with %s: %v", remotePeer, err)
// Don't write event if we can't get the node ID
return
}
// Get the actual remote node ID
remoteNodeID, ok := mapper.GetNodeIDForPeer(remotePeer)
if !ok {
log.Printf("Node ID not found for peer %s after successful exchange", remotePeer)
return
}
// Write edge created event with correct node IDs
writeEdgeCreatedEvent(localNodeID, remoteNodeID, localAddr, remoteAddr)
}()
}
// Disconnected is called when a connection is closed
@@ -201,9 +217,27 @@ func (n *NotifeeHandler) Disconnected(net network.Network, conn network.Conn) {
localAddr := conn.LocalMultiaddr()
remoteAddr := conn.RemoteMultiaddr()
// Clear connection tracker
connKey := fmt.Sprintf("%s-%s", conn.LocalPeer(), remotePeer)
connTrackerMu.Lock()
delete(connectionTracker, connKey)
connTrackerMu.Unlock()
// Get the actual node IDs (not peer IDs)
localNodeID := GetNodeId()
remoteNodeID := remotePeer.String() // TODO: Implement proper node ID mapping
// Get the remote node ID from the mapper
mapper := GetNodeIDMapper()
remoteNodeID, ok := mapper.GetNodeIDForPeer(remotePeer)
if !ok {
// Don't write event if we don't have the node ID mapping
log.Printf("No node ID mapping found for disconnected peer %s, skipping event", remotePeer)
mapper.RemoveMapping(remotePeer)
return
}
// Clean up the mapping
mapper.RemoveMapping(remotePeer)
// Create disconnection event
event := &TopologyEdgeDeleted{
@@ -253,6 +287,27 @@ func parseMultiaddr(ma multiaddr.Multiaddr) Multiaddr {
return result
}
// writeEdgeCreatedEvent builds a TopologyEdgeCreated event for the connection
// between localNodeID and remoteNodeID and persists it through writeEvent.
// The outcome (success or failure) is only logged; nothing is returned.
func writeEdgeCreatedEvent(localNodeID, remoteNodeID string, localAddr, remoteAddr multiaddr.Multiaddr) {
	edge := Connection{
		LocalNodeID:       localNodeID,
		SendBackNodeID:    remoteNodeID,
		LocalMultiaddr:    parseMultiaddr(localAddr),
		SendBackMultiaddr: parseMultiaddr(remoteAddr),
		ConnectionProfile: nil,
	}
	created := &TopologyEdgeCreated{
		EventType: EventTypeTopologyEdgeCreated,
		EventID:   uuid.New().String(),
		Edge:      edge,
	}

	err := writeEvent(EventTypeTopologyEdgeCreated, created)
	if err != nil {
		log.Printf("Failed to write edge created event: %v", err)
		return
	}
	log.Printf("Wrote edge created event: %s -> %s", localNodeID, remoteNodeID)
}
// GetNotifee returns a singleton instance of the notifee handler
func GetNotifee() network.Notifiee {
return &NotifeeHandler{}

View File

@@ -433,6 +433,9 @@ func getNode(ctx context.Context) {
// Register event notifiee to track topology changes
node.Network().Notify(GetNotifee())
// Set up node ID mapper
GetNodeIDMapper().SetHost(node)
// Start a goroutine to periodically trigger mDNS discovery
go periodicMDNSDiscovery()

View File

@@ -0,0 +1,185 @@
package forwarder
import (
"bufio"
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
)
const (
// NodeIDExchangeProtocol is the libp2p protocol ID under which the node ID
// exchange stream handler is registered and over which exchange streams are opened.
NodeIDExchangeProtocol = "/forwarder/nodeid/1.0.0"
// exchangeTimeout is the per-attempt timeout used by the node ID exchange.
exchangeTimeout = 5 * time.Second
)
// NodeIDMessage is the message format for node ID exchange.
// Both sides of the exchange send one JSON-encoded NodeIDMessage.
type NodeIDMessage struct {
// NodeID is the sender's node ID, serialized under the "node_id" key.
NodeID string `json:"node_id"`
}
// NodeIDMapper manages the mapping between peer IDs and node IDs.
// It keeps a forward and a reverse map and the libp2p host used to
// run the exchange protocol.
type NodeIDMapper struct {
// mu is taken by the accessor and mutator methods around map and host access.
mu sync.RWMutex
// peerToNode maps a libp2p peer ID to the node ID learned for it.
peerToNode map[peer.ID]string
// nodeToPeer is the reverse lookup: node ID to libp2p peer ID.
nodeToPeer map[string]peer.ID
// host is the libp2p host set via SetHost; nil until then.
host host.Host
}
var (
// nodeIDMapper is the process-wide mapper instance returned by GetNodeIDMapper.
nodeIDMapper *NodeIDMapper
// mapperOnce ensures nodeIDMapper is constructed exactly once.
mapperOnce sync.Once
)
// GetNodeIDMapper returns the singleton NodeIDMapper instance, creating it
// (with empty forward and reverse maps) on first use.
func GetNodeIDMapper() *NodeIDMapper {
	mapperOnce.Do(initNodeIDMapper)
	return nodeIDMapper
}

// initNodeIDMapper allocates the singleton mapper; it is run via mapperOnce.
func initNodeIDMapper() {
	nodeIDMapper = &NodeIDMapper{
		peerToNode: map[peer.ID]string{},
		nodeToPeer: map[string]peer.ID{},
	}
}
// SetHost sets the libp2p host for the mapper and registers
// handleNodeIDStream on that host for NodeIDExchangeProtocol.
// Both steps happen while the mapper's write lock is held.
func (m *NodeIDMapper) SetHost(h host.Host) {
m.mu.Lock()
defer m.mu.Unlock()
m.host = h
// Set up the stream handler for incoming node ID exchanges
h.SetStreamHandler(NodeIDExchangeProtocol, m.handleNodeIDStream)
}
// GetNodeIDForPeer looks up the node ID recorded for peerID. The boolean is
// false (and the string empty) when no mapping exists.
func (m *NodeIDMapper) GetNodeIDForPeer(peerID peer.ID) (string, bool) {
	m.mu.RLock()
	mapped, found := m.peerToNode[peerID]
	m.mu.RUnlock()
	return mapped, found
}
// GetPeerIDForNode looks up the peer ID recorded for nodeID. The boolean is
// false when no mapping exists.
func (m *NodeIDMapper) GetPeerIDForNode(nodeID string) (peer.ID, bool) {
	m.mu.RLock()
	mapped, found := m.nodeToPeer[nodeID]
	m.mu.RUnlock()
	return mapped, found
}
// SetMapping records that peerID corresponds to nodeID, updating both the
// forward (peer -> node) and reverse (node -> peer) maps under the write lock.
//
// If peerID was previously mapped to a different node ID, the reverse entry
// for that old node ID is removed (when it still points at peerID) so that
// GetPeerIDForNode cannot return a stale peer for it.
func (m *NodeIDMapper) SetMapping(peerID peer.ID, nodeID string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if previous, ok := m.peerToNode[peerID]; ok && previous != nodeID {
		// Only drop the old reverse entry if it is really ours; another peer
		// may have claimed that node ID in the meantime.
		if m.nodeToPeer[previous] == peerID {
			delete(m.nodeToPeer, previous)
		}
	}

	m.peerToNode[peerID] = nodeID
	m.nodeToPeer[nodeID] = peerID
	log.Printf("Mapped peer %s to node %s", peerID, nodeID)
}
// RemoveMapping removes the mapping for a peer. It is a no-op when the peer
// has no mapping.
//
// The reverse (node -> peer) entry is deleted only when it still points at
// peerID: if the same node ID has since been mapped to another peer, that
// newer reverse entry must survive the removal of the old peer.
func (m *NodeIDMapper) RemoveMapping(peerID peer.ID) {
	m.mu.Lock()
	defer m.mu.Unlock()

	nodeID, ok := m.peerToNode[peerID]
	if !ok {
		return
	}

	delete(m.peerToNode, peerID)
	if m.nodeToPeer[nodeID] == peerID {
		delete(m.nodeToPeer, nodeID)
	}
	log.Printf("Removed mapping for peer %s (was node %s)", peerID, nodeID)
}
// ExchangeNodeID initiates a node ID exchange with a peer and stores the
// resulting peer -> node mapping.
//
// It returns nil immediately when a mapping for peerID already exists, and an
// error when no host has been set or when every attempt fails (the error of
// the last attempt is returned). Up to three attempts are made, sleeping
// 100ms before the second and 200ms before the third.
func (m *NodeIDMapper) ExchangeNodeID(peerID peer.ID) error {
	// Read the host under the lock: SetHost writes it under the same mutex.
	m.mu.RLock()
	h := m.host
	m.mu.RUnlock()
	if h == nil {
		return fmt.Errorf("host not set")
	}

	// Check if we already have the mapping
	if _, ok := m.GetNodeIDForPeer(peerID); ok {
		return nil // Already have the mapping
	}

	const maxAttempts = 3
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			// Exponential backoff before each retry: 100ms, then 200ms
			time.Sleep(time.Duration(100<<uint(attempt-1)) * time.Millisecond)
		}

		remoteNodeID, err := exchangeNodeIDOnce(h, peerID)
		if err != nil {
			lastErr = err
			continue
		}

		// Store the mapping
		m.SetMapping(peerID, remoteNodeID)
		return nil
	}
	return lastErr
}

// exchangeNodeIDOnce performs a single exchange attempt over a fresh stream:
// it sends our node ID and returns the node ID the peer replies with.
func exchangeNodeIDOnce(h host.Host, peerID peer.ID) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), exchangeTimeout)
	defer cancel()

	// Open a stream to the peer
	stream, err := h.NewStream(ctx, peerID, NodeIDExchangeProtocol)
	if err != nil {
		return "", fmt.Errorf("failed to open stream: %w", err)
	}

	// The context only bounds opening the stream; apply the same deadline to
	// the stream itself so a silent peer cannot block the write/read forever.
	// Best effort: a failure to set the deadline does not abort the attempt.
	if deadline, ok := ctx.Deadline(); ok {
		_ = stream.SetDeadline(deadline)
	}

	// Send our node ID
	msg := NodeIDMessage{NodeID: GetNodeId()}
	if err := json.NewEncoder(stream).Encode(&msg); err != nil {
		_ = stream.Reset() // abort rather than gracefully close a failed exchange
		return "", fmt.Errorf("failed to send node ID: %w", err)
	}

	// Read their node ID
	var response NodeIDMessage
	if err := json.NewDecoder(bufio.NewReader(stream)).Decode(&response); err != nil {
		_ = stream.Reset()
		return "", fmt.Errorf("failed to read node ID: %w", err)
	}

	_ = stream.Close()
	return response.NodeID, nil
}
// handleNodeIDStream handles incoming node ID exchange requests: it reads the
// remote peer's NodeIDMessage, stores the peer -> node mapping, and replies
// with our own node ID. Failures are logged and end the exchange; the stream
// is always closed on return.
func (m *NodeIDMapper) handleNodeIDStream(stream network.Stream) {
	defer stream.Close()

	peerID := stream.Conn().RemotePeer()

	// Bound the whole exchange so a peer that opens a stream and never sends
	// anything cannot keep this handler blocked indefinitely.
	if err := stream.SetDeadline(time.Now().Add(exchangeTimeout)); err != nil {
		log.Printf("Failed to set deadline on node ID stream from %s: %v", peerID, err)
	}

	// Read their node ID
	decoder := json.NewDecoder(bufio.NewReader(stream))
	var msg NodeIDMessage
	if err := decoder.Decode(&msg); err != nil {
		log.Printf("Failed to read node ID from %s: %v", peerID, err)
		return
	}

	// Store the mapping
	m.SetMapping(peerID, msg.NodeID)

	// Send our node ID back
	response := NodeIDMessage{NodeID: GetNodeId()}
	encoder := json.NewEncoder(stream)
	if err := encoder.Encode(&response); err != nil {
		log.Printf("Failed to send node ID to %s: %v", peerID, err)
		return
	}
}

View File

@@ -0,0 +1,111 @@
package forwarder
import (
"bufio"
"context"
"encoding/json"
"log"
"testing"
"time"
"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockNodeIDStreamHandler returns a stream handler that reads one
// NodeIDMessage from the stream and answers with the given node ID.
func mockNodeIDStreamHandler(nodeID string) func(stream network.Stream) {
	handler := func(s network.Stream) {
		defer s.Close()

		remote := s.Conn().RemotePeer()

		// Consume the caller's node ID first
		var incoming NodeIDMessage
		if err := json.NewDecoder(bufio.NewReader(s)).Decode(&incoming); err != nil {
			log.Printf("Failed to read node ID from %s: %v", remote, err)
			return
		}

		// Then reply with the canned node ID
		reply := NodeIDMessage{NodeID: nodeID}
		if err := json.NewEncoder(s).Encode(&reply); err != nil {
			log.Printf("Failed to send node ID to %s: %v", remote, err)
		}
	}
	return handler
}
// TestNodeIDExchange connects two local libp2p hosts, has host 2 answer the
// exchange protocol with "node-2", and checks that ExchangeNodeID on host 1
// records that node ID for host 2's peer ID.
func TestNodeIDExchange(t *testing.T) {
	// Create two test hosts
	h1, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
	require.NoError(t, err)
	defer h1.Close()

	h2, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
	require.NoError(t, err)
	defer h2.Close()

	// Set up node ID for host 1
	SetNodeId("node-1")
	mapper1 := GetNodeIDMapper()
	mapper1.SetHost(h1)

	// Set up host 2 with a mock handler that responds with "node-2"
	h2.SetStreamHandler(NodeIDExchangeProtocol, mockNodeIDStreamHandler("node-2"))

	// Connect the hosts. The TTL is a time.Duration, so it must carry a unit:
	// a bare 3600 would mean 3600 nanoseconds, not one hour.
	h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
	require.NoError(t, err)

	// Exchange node IDs
	err = mapper1.ExchangeNodeID(h2.ID())
	require.NoError(t, err)

	// Verify the mapping on host 1
	nodeID, ok := mapper1.GetNodeIDForPeer(h2.ID())
	assert.True(t, ok)
	assert.Equal(t, "node-2", nodeID)
}
func TestNodeIDMapperOperations(t *testing.T) {
mapper := &NodeIDMapper{
peerToNode: make(map[peer.ID]string),
nodeToPeer: make(map[string]peer.ID),
}
// Test peer ID (simulated)
peerID := peer.ID("test-peer-id")
nodeID := "test-node-id"
// Set mapping
mapper.SetMapping(peerID, nodeID)
// Verify forward mapping
gotNodeID, ok := mapper.GetNodeIDForPeer(peerID)
assert.True(t, ok)
assert.Equal(t, nodeID, gotNodeID)
// Verify reverse mapping
gotPeerID, ok := mapper.GetPeerIDForNode(nodeID)
assert.True(t, ok)
assert.Equal(t, peerID, gotPeerID)
// Remove mapping
mapper.RemoveMapping(peerID)
// Verify removal
_, ok = mapper.GetNodeIDForPeer(peerID)
assert.False(t, ok)
_, ok = mapper.GetPeerIDForNode(nodeID)
assert.False(t, ok)
}

View File

@@ -39,7 +39,8 @@ members = [
"master",
"worker",
"shared",
"engines/*"
"engines/*",
"scripts"
]
[tool.uv.sources]

View File

@@ -1,25 +0,0 @@
import asyncio
from logging import Logger
from worker.main import get_node_id
from shared.types.common import NodeId
from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
async def main():
    """Print every event in the global event log, one line per event."""
    # The node ID is fetched but not used below; the call is kept as in the
    # original script.
    node_id: NodeId = get_node_id()
    logger: Logger = Logger('worker_log')

    manager: EventLogManager = EventLogManager(EventLogConfig(), logger)
    await manager.initialize()

    for wrapped in await manager.global_events.get_events_since(0):
        payload = wrapped.event
        label = type(payload).__name__.replace('_', ' ').title()
        rendered = ', '.join(f"{name}={val!r}" for name, val in vars(payload).items())
        print(f"{label}: {rendered}")
if __name__ == "__main__":
asyncio.run(main())

0
scripts/README.md Normal file
View File

30
scripts/pyproject.toml Normal file
View File

@@ -0,0 +1,30 @@
[project]
name = "exo-scripts"
version = "0.1.0"
description = "Scripts for the Exo project"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"exo-shared",
"huggingface_hub>=0.33.4",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build]
clean = true
[tool.hatch.build.targets.wheel]
packages = []
include = ["*"]
exclude = ["*.md", "pyproject.toml"]
[tool.hatch.build.targets.sdist]
packages = []
include = ["*"]
exclude = ["*.md", "pyproject.toml"]

516
scripts/read_events.py Normal file
View File

@@ -0,0 +1,516 @@
import asyncio
import curses
import time
import json
import argparse
import textwrap
import sys
from logging import Logger
from typing import List, Optional, Any, Sequence, Tuple
from shared.types.state import State
from shared.apply import apply
from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
from shared.types.events.components import EventFromEventLog
from shared.types.events import Event
# Globals
# Logger instance passed to EventLogManager by init_db().
logger: Logger = Logger('helper_log')
# Assigned by init_db(); None until then.
event_log_manager: Optional[EventLogManager] = None
# When True, non_tui_mode() keeps only WORKER_EVENT_TYPES events and prints a
# reduced state (node_status, instances, runners, tasks, last_event_applied_idx).
worker_mode: bool = False
# Worker-related event types
# Matched against the class name of each event (type(e.event).__name__).
WORKER_EVENT_TYPES = {
'TaskCreated', 'TaskStateUpdated', 'TaskFailed', 'TaskDeleted',
'ChunkGenerated',
'InstanceCreated', 'InstanceDeleted', 'InstanceActivated', 'InstanceDeactivated', 'InstanceReplacedAtomically',
'RunnerStatusUpdated', 'RunnerDeleted'
}
async def init_db() -> None:
    """Create the module-wide EventLogManager and initialise it."""
    global event_log_manager
    manager = EventLogManager(EventLogConfig(), logger)
    # Publish the manager before awaiting, exactly like a direct assignment.
    event_log_manager = manager
    await manager.initialize()
async def get_events_since(since: int) -> Sequence[EventFromEventLog[Event]]:
    """Fetch events from the global event log via its ``get_events_since(since)``.

    ``init_db()`` must have been awaited first; otherwise the module-level
    manager is still ``None`` and the attribute access fails.
    """
    global_log = event_log_manager.global_events  # type: ignore[union-attr]
    fetched = await global_log.get_events_since(since)  # type: ignore[attr-defined]
    return fetched  # type: ignore[return-value]
async def load_all_events() -> List[EventFromEventLog[Event]]:
    """Read the global event log batch by batch until a fetch comes back empty.

    The cursor passed to ``get_events_since`` advances by the number of events
    received so far.
    """
    collected: List[EventFromEventLog[Event]] = []
    cursor = 0
    while batch := await get_events_since(cursor):
        collected.extend(batch)
        cursor += len(batch)
    return collected
def compute_states(events: List[EventFromEventLog[Event]]) -> List[State]:
    """Return the state history produced by folding ``apply`` over ``events``.

    The result starts with a default ``State()`` and has one more entry per
    event, so ``result[i]`` is the state after the first ``i`` events.
    """
    history: List[State] = [State()]
    for wrapped in events:
        history.append(apply(history[-1], wrapped))
    return history
def print_event(event: EventFromEventLog[Event]) -> None:
    """Print one event as ``[idx] Type: key=value, ...`` on a single line."""
    payload = event.event
    label = type(payload).__name__.replace('_', ' ').title()
    fields = ', '.join(f"{key}={value!r}" for key, value in vars(payload).items())
    print(f"[{event.idx_in_log}] {label}: {fields}")
async def non_tui_mode() -> None:
    """Dump the final state and the event list to stdout (no curses UI).

    All events are always applied to compute the final state. In worker mode
    only events whose class name is in ``WORKER_EVENT_TYPES`` are listed, and
    only a fixed subset of the state's sections is printed.
    """
    await init_db()
    all_events = await load_all_events()
    final_state = compute_states(all_events)[-1]

    if not worker_mode:
        shown_events = all_events
        print("Final State:")
        print(final_state.model_dump_json(indent=2))
    else:
        shown_events = [e for e in all_events if type(e.event).__name__ in WORKER_EVENT_TYPES]
        # The state itself is cumulative over every event; it is not recomputed
        # from the filtered list, only trimmed to the worker-relevant sections.
        state_dict = json.loads(final_state.model_dump_json())
        section_defaults = (
            ('node_status', {}),
            ('instances', {}),
            ('runners', {}),
            ('tasks', {}),
            ('last_event_applied_idx', 0),
        )
        filtered_state = {name: state_dict.get(name, default) for name, default in section_defaults}
        print("Final State (filtered):")
        print(json.dumps(filtered_state, indent=2))

    print("\nEvents:")
    for item in shown_events:
        print_event(item)
async def update_events(wrapped_events: List[EventFromEventLog[Event]], states: List[State], filtered_indices: Optional[List[int]] = None) -> bool:
    """Append newly logged events (and their resulting states) in place.

    Fetches events using ``len(wrapped_events)`` as the cursor, appends one new
    state per fetched event to ``states``, and extends ``wrapped_events``. When
    ``filtered_indices`` is given, the positions of new worker-related events
    are appended to it.

    Returns True if anything new was fetched, False otherwise.
    """
    fetched = await get_events_since(len(wrapped_events))
    if not fetched:
        return False

    first_new = len(wrapped_events)
    for item in fetched:
        states.append(apply(states[-1], item))
    wrapped_events.extend(fetched)

    if filtered_indices is not None:
        filtered_indices.extend(
            pos
            for pos in range(first_new, len(wrapped_events))
            if type(wrapped_events[pos].event).__name__ in WORKER_EVENT_TYPES
        )
    return True
def draw_state(win: Any, state: State, height: int, width: int, worker_mode: bool, state_scroll: int) -> int:
    """Render ``state`` as pretty-printed, coloured JSON into ``win``.

    The state is dumped to JSON (trimmed to the worker-relevant sections when
    ``worker_mode`` is true), and up to ``height`` lines are drawn starting at
    the requested scroll offset, clamped so it never scrolls past the end.

    Colour pairs used: 3 for keys, 2 for string values, 4 for numbers, 5 for
    ``true``/``false``/``null``, 0 for everything else.

    Args:
        win: curses window to draw into (cleared first, refreshed at the end).
        state: state object providing ``model_dump_json()``.
        height: number of rows to draw.
        width: accepted for interface compatibility; not used for layout.
        worker_mode: when true, show only the worker-related sections.
        state_scroll: requested index of the first line to show.

    Returns:
        The scroll offset actually used after clamping.
    """
    win.clear()
    state_dict = json.loads(state.model_dump_json())
    if worker_mode:
        filtered_state = {
            'node_status': state_dict.get('node_status', {}),
            'instances': state_dict.get('instances', {}),
            'runners': state_dict.get('runners', {}),
            'tasks': state_dict.get('tasks', {}),
            'last_event_applied_idx': state_dict.get('last_event_applied_idx', 0)
        }
        state_pretty = json.dumps(filtered_state, indent=2)
    else:
        state_pretty = json.dumps(state_dict, indent=2)
    lines = state_pretty.split('\n')
    max_scroll = max(0, len(lines) - height)
    current_scroll = min(state_scroll, max_scroll)

    def put(y: int, x: int, text: str, *attrs: int) -> None:
        # curses raises when text runs past the window's last cell; a long
        # JSON line must not crash the viewer (render_event guards the same way).
        try:
            win.addstr(y, x, text, *attrs)
        except curses.error:
            pass

    for i in range(height):
        line_idx = current_scroll + i
        if line_idx >= len(lines):
            break
        line = lines[line_idx]
        y = i
        x = 0
        stripped = line.lstrip()
        leading_spaces = len(line) - len(stripped)
        put(y, x, ' ' * leading_spaces)
        x += leading_spaces

        end_key = stripped.find('": ') if stripped.startswith('"') else -1
        if end_key == -1:
            # Not a `"key": value` line (brackets, bare array items, ...).
            put(y, x, stripped)
            continue

        key_str = stripped[:end_key + 3]  # include the closing quote, colon and space
        put(y, x, key_str, curses.color_pair(3))
        x += len(key_str)

        value_str = stripped[end_key + 3:]
        # Every member but the last ends with a comma in pretty-printed JSON;
        # classify the value without it so numbers and literals are recognised.
        bare = value_str[:-1] if value_str.endswith(',') else value_str
        if bare.startswith('"'):
            color = 2
        elif bare.replace('.', '', 1).isdigit() or (bare.startswith('-') and bare[1:].replace('.', '', 1).isdigit()):
            color = 4
        elif bare in ('true', 'false', 'null'):
            color = 5
        else:
            # Opening/closing brackets and anything unrecognised.
            color = 0
        put(y, x, value_str, curses.color_pair(color))
    win.refresh()
    return current_scroll
def get_event_pairs(event: EventFromEventLog[Event]) -> List[Tuple[str, int]]:
    """Break an event into ``(text, colour_pair)`` segments for rendering.

    The segments are, in order: the ``[idx] `` prefix (colour 5), the event
    type name (1), ``": "`` (0), then for each attribute ``key`` (3), ``=`` (0)
    and ``repr(value)``, with ``", "`` (0) between attributes. Value colours:
    2 for ``str``, 5 for ``bool``, 4 for ``int``/``float``, 6 otherwise.
    """
    pairs: List[Tuple[str, int]] = []
    pairs.append((f"[{event.idx_in_log}] ", 5))
    event_type = type(event.event).__name__.replace('_', ' ').title()
    pairs.append((event_type, 1))
    pairs.append((": ", 0))
    for position, (key, value) in enumerate(vars(event.event).items()):
        if position:
            pairs.append((", ", 0))
        pairs.append((key, 3))
        pairs.append(("=", 0))
        if isinstance(value, str):
            color = 2
        elif isinstance(value, bool):
            # bool is a subclass of int, so it must be tested before the
            # numeric branch or it could never be reached.
            color = 5
        elif isinstance(value, (int, float)):
            color = 4
        else:
            color = 6
        pairs.append((repr(value), color))
    return pairs
def calculate_event_lines(pairs: List[Tuple[str, int]], win_width: int, subsequent_indent: int) -> int:
    """Count the rows needed to lay out ``pairs`` in a window ``win_width`` wide.

    Segments are placed one after another; when a row is full the text
    continues on the next row starting at column ``subsequent_indent``.

    The width is treated as at least 1 and the continuation indent is capped
    at ``width - 1`` so every wrapped row can hold at least one character.
    Without that cap a window no wider than the indent would never make
    progress and the loop would not terminate.
    """
    width = max(win_width, 1)
    indent = min(subsequent_indent, width - 1)
    lines = 1
    x = 0
    for text, _ in pairs:
        i = 0
        while i < len(text):
            part_len = min(len(text) - i, width - x)
            i += part_len
            x += part_len
            if i < len(text):
                # Row exhausted with text left over: continue on a new row.
                lines += 1
                x = indent
    return lines
def render_event(win: Any, start_y: int, pairs: List[Tuple[str, int]], is_bold: bool, win_width: int, subsequent_indent: int) -> int:
    """Draw the ``(text, colour_pair)`` segments of one event into ``win``.

    Drawing starts at row ``start_y``, column 0. Text that does not fit in the
    current row continues on the next row at column ``subsequent_indent``.
    ``curses.error`` from ``addstr`` is ignored. Drawing stops early once the
    next row would fall outside the window.

    Returns the row following the last one used (or the out-of-range row when
    drawing stopped early).
    """
    row = start_y
    col = 0
    bold_flag = curses.A_BOLD if is_bold else 0
    for text, color in pairs:
        attr = curses.color_pair(color) | bold_flag
        pos = 0
        while pos < len(text):
            take = min(len(text) - pos, win_width - col)
            chunk = text[pos:pos + take]
            try:
                win.addstr(row, col, chunk, attr)
            except curses.error:
                pass
            pos += take
            col += take
            if pos >= len(text):
                break
            # More text remains for this segment: wrap to the next row.
            row += 1
            if row >= win.getmaxyx()[0]:
                return row
            col = subsequent_indent
    return row + 1 if col > 0 else row
def draw_events(win: Any, events_list: List[EventFromEventLog[Event]], current_events: int, height: int) -> None:
    """Redraw the events pane with the selected event highlighted.

    The selected event is drawn in bold. As many whole events as fit are shown
    above and below it, the free rows being split roughly evenly, and the
    resulting group is vertically centred in the pane. If the selected event
    alone needs more than ``height`` rows, only it is drawn, from the top.

    Args:
        win: Curses window to draw into.
        events_list: Events to choose from (already filtered by the caller).
        current_events: Index into ``events_list`` of the selected event.
        height: Number of rows available in ``win``.
    """
    win.clear()
    if not events_list:
        win.addstr(0, 0, "No events")
        win.refresh()
        return
    win_width = win.getmaxyx()[1]

    def layout(index: int) -> Tuple[List[Tuple[str, int]], int, int]:
        """Return (fragments, continuation indent, row count) for one event."""
        event = events_list[index]
        pairs = get_event_pairs(event)
        indent = len(f"[{event.idx_in_log}] ")
        return pairs, indent, calculate_event_lines(pairs, win_width, indent)

    def collect(indices: Any, budget: int) -> List[Tuple[List[Tuple[str, int]], int, int]]:
        """Take events in ``indices`` order until one no longer fits in ``budget`` rows."""
        picked = []
        for index in indices:
            if budget <= 0:
                break
            item = layout(index)
            if item[2] > budget:
                break
            budget -= item[2]
            picked.append(item)
        return picked

    current_pairs, subsequent_indent, lines_current = layout(current_events)
    if lines_current > height:
        render_event(win, 0, current_pairs, True, win_width, subsequent_indent)
        win.refresh()
        return
    target_above = (height - lines_current) // 2
    target_below = height - lines_current - target_above
    # Walk outwards from the selection; the "above" list is gathered nearest
    # first, so flip it back into display order.
    prev_items = collect(range(current_events - 1, -1, -1), target_above)
    prev_items.reverse()
    next_items = collect(range(current_events + 1, len(events_list)), target_below)
    total_lines = lines_current + sum(item[2] for item in prev_items) + sum(item[2] for item in next_items)
    y = (height - total_lines) // 2 if total_lines < height else 0
    for pairs, indent, _ in prev_items:
        y = render_event(win, y, pairs, False, win_width, indent)
    y = render_event(win, y, current_pairs, True, win_width, subsequent_indent)
    for pairs, indent, _ in next_items:
        y = render_event(win, y, pairs, False, win_width, indent)
    win.refresh()
def draw_status(win: Any, realtime: bool, current: int, total_events: int) -> None:
    """Redraw the one-line status bar.

    Shows the current mode, the selected event index out of ``total_events``,
    and the key bindings. The text is cut to the window width so that a
    terminal narrower than the message does not make curses raise.
    """
    win.clear()
    mode = "Realtime" if realtime else "Timetravel"
    status = f"Mode: {mode} | Current event: {current} / {total_events} | Arrows: navigate events, [/]: scroll state, g: goto, r: toggle realtime, q: quit"
    # Leave the last cell free: writing into the bottom-right corner of a
    # window makes curses report an error even though the text was drawn.
    max_chars = win.getmaxyx()[1] - 1
    if max_chars > 0:
        win.addnstr(0, 0, status, max_chars)
    win.refresh()
def get_input(stdscr: Any, prompt: str) -> str:
    """Show ``prompt`` at the top-left of ``stdscr`` and read up to 20 characters.

    Echo is switched on while the user types and is always switched off again,
    even if reading fails. Bytes that are not valid UTF-8 are replaced rather
    than raising, so the caller only has to deal with an unparsable string.

    Returns:
        The text the user entered.
    """
    curses.echo()
    try:
        stdscr.addstr(0, 0, prompt)
        stdscr.refresh()
        raw = stdscr.getstr(0, len(prompt), 20)
    finally:
        curses.noecho()
    return raw.decode('utf-8', errors='replace')
def get_key(win: Any) -> Any:
    """Read one key from ``win``, decoding a few raw escape sequences.

    Recognised sequences (after ESC ``[``): ``A``/``B`` map to
    ``curses.KEY_UP``/``curses.KEY_DOWN``, ``5~``/``6~`` map to
    ``curses.KEY_PPAGE``/``curses.KEY_NPAGE``, and ``1;5A``/``1;5B`` map to the
    strings ``'CTRL_UP'``/``'CTRL_DOWN'``.

    Returns:
        ``-1`` when no key is available or a sequence is cut short after
        ESC ``[``; the decoded key for a recognised sequence; ``27`` for a lone
        or unrecognised escape sequence; otherwise the raw key code.
    """
    first = win.getch()
    if first != 27:
        # Plain key, or -1 when nothing was pressed.
        return first
    second = win.getch()
    if second == -1:
        return 27
    if second != 91:
        return first
    third = win.getch()
    if third == -1:
        return -1
    arrows = {65: curses.KEY_UP, 66: curses.KEY_DOWN}
    if third in arrows:
        return arrows[third]
    paging = {53: curses.KEY_PPAGE, 54: curses.KEY_NPAGE}
    if third in paging:
        return paging[third] if win.getch() == 126 else first
    if third != 49:
        return first
    # ESC [ 1 ; 5 <A|B> is the modified-arrow form handled here.
    fourth = win.getch()
    if fourth == -1:
        return -1
    if fourth != 59:
        return first
    fifth = win.getch()
    if fifth == -1:
        return -1
    if fifth != 53:
        return first
    sixth = win.getch()
    if sixth == -1:
        return -1
    if sixth == 65:
        return 'CTRL_UP'
    if sixth == 66:
        return 'CTRL_DOWN'
    return first
def tui(stdscr: Any) -> None:
    """Run the interactive curses event-log viewer until ``q`` is pressed.

    The screen is split into a state pane (left), an events pane (right) and a
    one-line status bar. Key bindings: Up/Down move the selection by one event,
    Ctrl+Up/Ctrl+Down by five, ``[``/``]`` scroll the state pane by half a
    pane, ``g`` prompts for an ``idx_in_log`` to jump to, ``r`` toggles
    realtime mode and ``q`` quits. In realtime mode the log is re-read at most
    once per ``update_interval`` seconds and the selection follows the newest
    event. When the module-level ``worker_mode`` flag is set, only events whose
    type name is in ``WORKER_EVENT_TYPES`` are selectable.

    Args:
        stdscr: The root curses window supplied by ``curses.wrapper``.
    """
    curses.start_color()
    curses.init_pair(1, curses.COLOR_BLUE, curses.COLOR_BLACK)
    curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK)
    curses.init_pair(3, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
    curses.init_pair(4, curses.COLOR_YELLOW, curses.COLOR_BLACK)
    curses.init_pair(5, curses.COLOR_CYAN, curses.COLOR_BLACK)
    curses.init_pair(6, curses.COLOR_WHITE, curses.COLOR_BLACK)
    curses.use_default_colors()
    stdscr.timeout(100)  # short key wait so the loop keeps redrawing and polling
    curses.curs_set(0)
    wrapped_events: List[EventFromEventLog[Event]] = []
    states: List[State] = [State()]
    asyncio.run(init_db())
    asyncio.run(update_events(wrapped_events, states))  # Initial load
    filtered_indices: Optional[List[int]] = None
    current_filtered: int = -1
    current: int = -1
    if worker_mode:
        filtered_indices = [i for i in range(len(wrapped_events)) if type(wrapped_events[i].event).__name__ in WORKER_EVENT_TYPES]
        current_filtered = len(filtered_indices) - 1 if filtered_indices else -1
    else:
        current = len(wrapped_events) - 1 if wrapped_events else -1
    realtime: bool = False
    last_update: float = time.time()
    update_interval: float = 1.0
    state_scroll: int = 0
    while True:
        height, width = stdscr.getmaxyx()
        status_height = 1
        pane_height = height - status_height
        pane_width = width // 2
        state_win = curses.newwin(pane_height, pane_width, 0, 0)
        events_win = curses.newwin(pane_height, width - pane_width, 0, pane_width)
        status_win = curses.newwin(status_height, width, pane_height, 0)
        if worker_mode:
            assert filtered_indices is not None
            current_original = filtered_indices[current_filtered] if current_filtered >= 0 else -1
            events_list = [wrapped_events[i] for i in filtered_indices]
            current_events = current_filtered
        else:
            current_original = current
            events_list = wrapped_events
            current_events = current
        # states[0] is the initial state; states[k + 1] is shown for event k.
        state_idx = current_original + 1 if current_original >= 0 else 0
        state_scroll = draw_state(state_win, states[state_idx], pane_height, pane_width, worker_mode, state_scroll)
        draw_events(events_win, events_list, current_events, pane_height)
        total_events = len(wrapped_events) - 1 if wrapped_events else -1
        draw_status(status_win, realtime, current_original if worker_mode else current, total_events)
        key = get_key(stdscr)
        if key != -1:
            if key == curses.KEY_UP:
                if worker_mode and current_filtered > 0:
                    current_filtered -= 1
                elif not worker_mode and current > 0:
                    current -= 1
            elif key == 'CTRL_UP':
                # Only move when something is selected: with an empty log the
                # index is -1 and must stay there, otherwise the next frame
                # indexes into empty lists.
                if worker_mode:
                    if current_filtered > 0:
                        current_filtered = max(0, current_filtered - 5)
                elif current > 0:
                    current = max(0, current - 5)
            elif key == curses.KEY_DOWN:
                if worker_mode and current_filtered < len(filtered_indices) - 1:  # type: ignore[arg-type]
                    current_filtered += 1
                elif not worker_mode and current < len(wrapped_events) - 1:
                    current += 1
            elif key == 'CTRL_DOWN':
                if worker_mode:
                    current_filtered = min(len(filtered_indices) - 1, current_filtered + 5)  # type: ignore[arg-type]
                else:
                    current = min(len(wrapped_events) - 1, current + 5)
            elif key == ord('['):
                state_scroll = max(0, state_scroll - pane_height // 2)
            elif key == ord(']'):
                state_scroll += pane_height // 2  # clamped in draw_state
            elif key == ord('q'):
                break
            elif key == ord('r'):
                realtime = not realtime
                if realtime:
                    if worker_mode:
                        current_filtered = len(filtered_indices) - 1 if filtered_indices else -1  # type: ignore[arg-type]
                    else:
                        current = len(wrapped_events) - 1 if wrapped_events else -1
                    state_scroll = 0
            elif key == ord('g'):
                stdscr.timeout(-1)  # block for input
                input_str = get_input(status_win, "Go to event: ")
                try:
                    goto = int(input_str)
                    if worker_mode:
                        assert filtered_indices is not None
                        for i, orig in enumerate(filtered_indices):
                            if wrapped_events[orig].idx_in_log == goto:
                                current_filtered = i
                                state_scroll = 0
                                break
                    else:
                        for i in range(len(wrapped_events)):
                            if wrapped_events[i].idx_in_log == goto:
                                current = i
                                state_scroll = 0
                                break
                except ValueError:
                    # Non-numeric input: keep the current selection.
                    pass
                stdscr.timeout(100)
                status_win.clear()
                status_win.refresh()
        if realtime and time.time() - last_update > update_interval:
            updated = asyncio.run(update_events(wrapped_events, states, filtered_indices if worker_mode else None))
            if updated:
                if worker_mode:
                    current_filtered = len(filtered_indices) - 1  # type: ignore[arg-type]
                else:
                    current = len(wrapped_events) - 1
                state_scroll = 0
            last_update = time.time()
if __name__ == "__main__":
    # Command-line entry point: parse flags, then pick the TUI or plain output.
    parser = argparse.ArgumentParser(description='Read and display events from the event log')
    parser.add_argument('--worker', action='store_true', help='Only show worker-related events (task, streaming, instance, runner status)')
    args = parser.parse_args()
    worker_mode = args.worker
    if sys.stdout.isatty():
        try:
            curses.wrapper(tui)
        except curses.error as e:
            # Only the "no terminal" failure is recoverable; anything else is a real error.
            if "could not find terminal" not in str(e):
                raise
            print("Error: Could not find terminal. Falling back to non-TUI mode.")
            asyncio.run(non_tui_mode())
    else:
        # Output is piped or redirected, so curses cannot be used.
        asyncio.run(non_tui_mode())

12
scripts/test_download.py Normal file
View File

@@ -0,0 +1,12 @@
import asyncio

from worker.download.download_utils import file_meta


async def main():
    """Fetch and print the metadata ``file_meta`` reports for one Hugging Face file.

    Manual smoke test: performs a real network request for ``config.json`` of
    ``mlx-community/DeepSeek-R1-4bit`` at revision ``main``.
    """
    meta = await file_meta(
        'mlx-community/DeepSeek-R1-4bit',
        revision='main',
        path='config.json',
        redirected_location=None,
    )
    print(meta)


# Guarded so importing this module does not trigger a network request.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -140,10 +140,10 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_performance_measured(event: NodePerformanceMeasured, state: State) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {**state.node_profiles, event.node_id: event.node_profile}
state = state.model_copy(update={"node_profiles": new_profiles})
if not state.topology.contains_node(event.node_id):
# TODO: figure out why this is happening in the first place
return state
topology = copy.copy(state.topology)
if not topology.contains_node(event.node_id):
# TODO: figure out why this is happening in the first place
topology.add_node(Node(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(update={"topology": topology})
@@ -164,13 +164,6 @@ def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> Sta
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
opposite_edge = Connection(
local_node_id=event.edge.send_back_node_id,
send_back_node_id=event.edge.local_node_id,
local_multiaddr=event.edge.send_back_multiaddr,
send_back_multiaddr=event.edge.local_multiaddr
)
topology.add_connection(opposite_edge)
return state.model_copy(update={"topology": topology})
@event_apply.register(TopologyEdgeReplacedAtomically)

View File

@@ -20,6 +20,10 @@ EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
LIBP2P_WORKER_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
LB_TFLOPS = 2.3
LB_MEMBW_GBPS = 68
LB_DISK_GBPS = 1.5
# little helper function to get the name of the module that raised the error
def get_caller_module_name() -> str:

View File

@@ -14,7 +14,20 @@ class ModelCard(BaseModel):
metadata: ModelMetadata
MODEL_CARDS = {
MODEL_CARDS: dict[str, ModelCard] = {
"deepseek-v3-0324": ModelCard(
short_id="deepseek-v3-0324",
model_id="mlx-community/DeepSeek-v3-0324-8bit",
name="DeepSeek V3 fp8",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-v3-0324-8bit",
pretty_name="DeepSeek V3 fp8",
storage_size_kilobytes=754998771712//1024,
n_layers=61,
),
),
"llama-3.3": ModelCard(
short_id="llama-3.3",
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",

View File

@@ -1,6 +1,7 @@
from typing import Annotated, Dict, Optional
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from pydantic import BaseModel, Field
@@ -8,7 +9,7 @@ from shared.types.models import ModelMetadata
from worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_exo_tmp,
ensure_models_dir,
)
@@ -43,14 +44,16 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_exo_tmp())/model_id.replace("/", "--")
target_dir = (await ensure_models_dir())/str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(model_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(config_path, 'r') as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> int:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_exo_tmp())/model_id.replace("/", "--")
target_dir = (await ensure_models_dir())/str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(model_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(index_path, 'r') as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

View File

@@ -161,6 +161,22 @@ class Topology(TopologyProto):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[Node]) -> bool:
    """Return True if every in-cycle link has at least one Thunderbolt edge.

    For each node of ``cycle`` and each of its graph neighbours that is also
    in ``cycle``, at least one edge between the two must report
    ``is_thunderbolt()``. Neighbours outside the cycle are ignored.
    """
    rx_idxs = [self._node_id_to_rx_id_map[node.node_id] for node in cycle]
    for rid in rx_idxs:
        for neighbor_rid in self._graph.neighbors(rid):
            if neighbor_rid not in rx_idxs:
                continue
            edges = self._graph.get_all_edge_data(rid, neighbor_rid)
            if not any(edge.is_thunderbolt() for edge in edges):
                return False
    return True
def _is_bridge(self, connection: Connection) -> bool:
"""Check if removing this connection will orphan any nodes from the master."""
if self.master_node_id is None:

View File

@@ -1,4 +1,4 @@
from ipaddress import IPv4Address
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Self
from uuid import uuid4
@@ -29,7 +29,7 @@ class CommandId(ID):
class Host(BaseModel):
ip: IPv4Address
ip: IPv4Address | IPv6Address
port: int
def __str__(self) -> str:

View File

@@ -16,6 +16,7 @@ class CommandType(str, Enum):
CHAT_COMPLETION = "CHAT_COMPLETION"
CREATE_INSTANCE = "CREATE_INSTANCE"
DELETE_INSTANCE = "DELETE_INSTANCE"
TASK_FINISHED = "TASK_FINISHED"
class _BaseCommand[T: CommandType](BaseModel):
@@ -39,8 +40,12 @@ class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]):
instance_id: InstanceId
class TaskFinishedCommand(_BaseCommand[CommandType.TASK_FINISHED]):
command_type: Literal[CommandType.TASK_FINISHED] = CommandType.TASK_FINISHED
Command = Annotated[
ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand,
ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand | TaskFinishedCommand,
Field(discriminator="command_type")
]

View File

@@ -1,5 +1,5 @@
import re
from ipaddress import IPv4Address
from ipaddress import IPv4Address, IPv6Address
from typing import ClassVar
from pydantic import BaseModel, computed_field, field_serializer, field_validator
@@ -25,6 +25,20 @@ class Multiaddr(BaseModel):
return v
@computed_field
@property
def address_type(self) -> str:
    """Return the protocol segment of the first pattern in ``PATTERNS`` that matches.

    The value is the text between the first and second ``/`` of the matching
    pattern string.

    Raises:
        ValueError: If no pattern matches ``self.address``.
    """
    matching = next((p for p in self.PATTERNS if re.match(p, self.address)), None)
    if matching is None:
        raise ValueError(f"Invalid multiaddr format: {self.address}")
    return matching.split('/')[1]
@property
def ipv6_address(self) -> IPv6Address:
    """Parse the IPv6 host out of an ``/ip6/...`` multiaddr.

    The host may end in a dotted quad (for example ``::ffff:192.168.1.1``), so
    dots are accepted; stopping at the first dot would silently yield a
    different, still valid, address.

    Raises:
        ValueError: If the address does not start with ``/ip6/<host>`` or the
            host is not a valid IPv6 address.
    """
    match = re.match(r'^/ip6/([0-9a-fA-F:.]+)', self.address)
    if not match:
        raise ValueError(f"Invalid multiaddr format: {self.address}. Expected format like /ip6/::1/tcp/4001")
    return IPv6Address(match.group(1))
@property
def ipv4_address(self) -> IPv4Address:
match = re.match(r'^/ip4/(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})', self.address)
@@ -32,11 +46,15 @@ class Multiaddr(BaseModel):
raise ValueError(f"Invalid multiaddr format: {self.address}. Expected format like /ip4/127.0.0.1/tcp/4001")
return IPv4Address(match.group(1))
@field_serializer("ipv4_address")
def serialize_ipv4_address(self, value: IPv4Address) -> str:
@computed_field
@property
def ip_address(self) -> IPv4Address | IPv6Address:
return self.ipv4_address if self.address_type == 'ip4' else self.ipv6_address
@field_serializer("ip_address")
def serialize_ipv4_address(self, value: IPv4Address | IPv6Address) -> str:
return str(value)
@computed_field
@property
def port(self) -> int:

View File

@@ -22,8 +22,8 @@ class Connection(BaseModel):
(
self.local_node_id,
self.send_back_node_id,
self.local_multiaddr.ipv4_address,
self.send_back_multiaddr.ipv4_address,
self.local_multiaddr.ip_address,
self.send_back_multiaddr.ip_address,
)
)
@@ -33,9 +33,12 @@ class Connection(BaseModel):
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.local_multiaddr.ipv4_address == other.local_multiaddr.ipv4_address
and self.send_back_multiaddr.ipv4_address == other.send_back_multiaddr.ipv4_address
and self.local_multiaddr.ip_address == other.local_multiaddr.ip_address
and self.send_back_multiaddr.ip_address == other.send_back_multiaddr.ip_address
)
def is_thunderbolt(self) -> bool:
    """Return True when both endpoint addresses start with ``169.254``."""
    for multiaddr in (self.local_multiaddr, self.send_back_multiaddr):
        if not str(multiaddr.ip_address).startswith('169.254'):
            return False
    return True
class Node(BaseModel):

View File

@@ -1,4 +1,5 @@
from enum import Enum
from typing import Optional
from shared.types.common import ID
@@ -14,3 +15,12 @@ class RunnerId(ID):
class NodeStatus(str, Enum):
Idle = "Idle"
Running = "Running"
class RunnerError(Exception):
    """Exception raised when the runner process encounters an error.

    Carries the error's type name, its message and, when available, the
    formatted traceback text. ``str(error)`` is ``"<error_type>: <error_message>"``.
    """

    def __init__(self, error_type: str, error_message: str, traceback: Optional[str] = None):
        super().__init__(f"{error_type}: {error_message}")
        self.error_type = error_type
        self.error_message = error_message
        self.traceback = traceback

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated, Generic, Literal, TypeVar
from typing import Annotated, Generic, Literal, Optional, TypeVar
from pydantic import BaseModel, Field, TypeAdapter
@@ -24,6 +24,11 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
partition_strategy: PartitionStrategyT
device_rank: int
world_size: int
# Error handling; equivalent to monkey-patch, but we can't monkey-patch runner.py
# This is kinda annoying because it allocates memory in the ShardMetadata object. Can be rethought after Shanghai.
immediate_exception: bool = False
should_timeout: Optional[float] = None
class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):

31
uv.lock generated
View File

@@ -15,6 +15,7 @@ members = [
"exo",
"exo-engine-mlx",
"exo-master",
"exo-scripts",
"exo-shared",
"exo-worker",
]
@@ -303,6 +304,21 @@ requires-dist = [
{ name = "uvicorn", specifier = ">=0.35.0" },
]
[[package]]
name = "exo-scripts"
version = "0.1.0"
source = { editable = "scripts" }
dependencies = [
{ name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.metadata]
requires-dist = [
{ name = "exo-shared", editable = "shared" },
{ name = "huggingface-hub", specifier = ">=0.33.4" },
]
[[package]]
name = "exo-shared"
version = "0.1.0"
@@ -365,6 +381,7 @@ dependencies = [
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.metadata]
@@ -373,6 +390,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "mlx", specifier = "==0.26.3" },
{ name = "mlx-lm", specifier = ">=0.25.3" },
{ name = "psutil", specifier = ">=7.0.0" },
]
[[package]]
@@ -840,6 +858,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/af/ab3c51ab7507a7325e98ffe691d9495ee3d3aa5f589afad65ec920d39821/protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e", size = 168724, upload-time = "2025-05-28T19:25:53.926Z" },
]
[[package]]
name = "psutil"
version = "7.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003, upload-time = "2025-02-13T21:54:07.946Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051, upload-time = "2025-02-13T21:54:12.36Z" },
{ url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535, upload-time = "2025-02-13T21:54:16.07Z" },
{ url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004, upload-time = "2025-02-13T21:54:18.662Z" },
{ url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986, upload-time = "2025-02-13T21:54:21.811Z" },
{ url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544, upload-time = "2025-02-13T21:54:24.68Z" },
]
[[package]]
name = "pycparser"
version = "2.22"

View File

@@ -2,7 +2,6 @@ import asyncio
import hashlib
import os
import shutil
import tempfile
import time
import traceback
from datetime import timedelta
@@ -91,9 +90,6 @@ class RepoDownloadProgress(BaseModel):
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_HOME / "models" / model_id.replace("/", "--")
def exo_tmp() -> Path:
return Path(tempfile.gettempdir())/"exo"
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir())/repo_id.replace("/", "--")
@@ -101,10 +97,6 @@ async def ensure_exo_home() -> Path:
await aios.makedirs(EXO_HOME, exist_ok=True)
return EXO_HOME
async def ensure_exo_tmp() -> Path:
await aios.makedirs(exo_tmp(), exist_ok=True)
return exo_tmp()
async def has_exo_home_read_access() -> bool:
try:
return await aios.access(EXO_HOME, os.R_OK)
@@ -146,7 +138,9 @@ async def seed_models(seed_dir: Union[str, Path]):
traceback.print_exc()
async def fetch_file_list_with_cache(repo_id: str, revision: str = "main", recursive: bool = False) -> List[FileListEntry]:
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
target_dir = (await ensure_models_dir())/"caches"/str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, 'r') as f:
return TypeAdapter(List[FileListEntry]).validate_json(await f.read())
@@ -198,22 +192,29 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
hasher.update(chunk)
return hasher.hexdigest()
async def file_meta(repo_id: str, revision: str, path: str, redirected_location: str | None = None) -> Tuple[int, str]:
# NOTE: huggingface broke the E-Tag so we can no longer assume E-Tag == sha256(file)
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path) if redirected_location is None else f"{get_hf_endpoint()}{redirected_location}"
headers = await get_auth_headers()
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session, session.head(url, headers=headers) as r:
if r.status == 307:
redirected_location = r.headers.get('Location')
return await file_meta(repo_id, revision, path, redirected_location)
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
etag = etag[1:-1]
return content_length, etag
async def file_meta(repo_id: str, revision: str, path: str, redirected_location: str | None = None) -> Tuple[int, str]:
    """Return ``(size_in_bytes, etag)`` for a repository file via an HTTP HEAD request.

    The size comes from ``x-linked-size`` or ``content-length`` and the hash
    from ``X-Linked-ETag`` or ``ETag``; one pair of surrounding quotes is
    stripped from the hash. On a 307 response the values are used directly
    when the headers already carry them; otherwise the ``Location`` header is
    followed with a new request.

    Args:
        repo_id: Repository identifier used to build the resolve URL.
        revision: Revision segment of the resolve URL.
        path: File path inside the repository.
        redirected_location: Path of a redirect target to query instead of the
            resolve URL; it is appended to the endpoint.

    Raises:
        AssertionError: If the response has no usable size or hash, or a
            redirect carries neither metadata nor a ``Location`` header.
    """
    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path) if redirected_location is None else f"{get_hf_endpoint()}{redirected_location}"
    headers = await get_auth_headers()
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session, session.head(url, headers=headers) as r:
        content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
        etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
        # A redirect is only followed when its own headers lack the metadata.
        follow_redirect = r.status == 307 and not (content_length > 0 and etag is not None)
        if not follow_redirect:
            assert content_length > 0, f"No content length for {url}"
            assert etag is not None, f"No remote hash for {url}"
            # Length check keeps an empty header value from raising IndexError.
            if len(etag) >= 2 and ((etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'")):
                etag = etag[1:-1]
            return content_length, etag
        redirected_location = r.headers.get('Location')
    # Without a target the next call would repeat this same request forever.
    assert redirected_location is not None, f"Redirect without Location header for {url}"
    # Recurse only after the session above is closed, so at most one stays open.
    return await file_meta(repo_id, revision, path, redirected_location)
async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
n_attempts = 30
@@ -291,7 +292,8 @@ def calculate_repo_progress(shard: ShardMetadata, repo_id: str, revision: str, f
)
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
target_dir = (await ensure_models_dir())/str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

View File

@@ -9,16 +9,17 @@ from shared.types.events import (
)
from shared.types.profiling import NodePerformanceProfile
from shared.types.worker.ops import (
ExecuteTaskOp,
RunnerOp,
)
from shared.utils import get_node_id_keypair
from shared.utils import Keypair, get_node_id_keypair
from worker.download.impl_shard_downloader import exo_shard_downloader
from worker.plan import plan
from worker.utils.profile import start_polling_node_metrics
from worker.worker import Worker
async def run(worker_state: Worker):
async def run(worker_state: Worker, logger: logging.Logger):
assert worker_state.global_events is not None
while True:
@@ -42,15 +43,26 @@ async def run(worker_state: Worker):
# run the op, synchronously blocking for now
if op is not None:
async for event in worker_state.execute_op(op):
await worker_state.event_publisher(event)
logger.info(f'Executing op {op}')
try:
async for event in worker_state.execute_op(op):
await worker_state.event_publisher(event)
except Exception as e:
if isinstance(op, ExecuteTaskOp):
generator = worker_state.fail_task(e, runner_id=op.runner_id, task_id=op.task.task_id)
else:
generator = worker_state.fail_runner(e, runner_id=op.runner_id)
async for event in generator:
await worker_state.event_publisher(event)
await asyncio.sleep(0.01)
async def main():
node_id_keypair = get_node_id_keypair()
node_id_keypair: Keypair = get_node_id_keypair()
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
logger: logging.Logger = logging.getLogger('worker_logger')
logger.setLevel(logging.DEBUG)
@@ -72,7 +84,7 @@ async def main():
worker = Worker(node_id, logger, shard_downloader, event_log_manager.worker_events, event_log_manager.global_events)
await run(worker)
await run(worker, logger)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -95,7 +95,8 @@ def spin_down_runners(
num_spundown_nodes = 0
for runner_id in instance.shard_assignments.runner_to_shard:
if isinstance(state_runners[runner_id], InactiveRunnerStatus) and \
if runner_id in state_runners and \
isinstance(state_runners[runner_id], InactiveRunnerStatus) and \
runner_id not in assigned_runners:
num_spundown_nodes += 1
# Suggested:

View File

@@ -9,7 +9,7 @@ dependencies = [
"huggingface_hub>=0.33.4",
"mlx==0.26.3",
"mlx-lm>=0.25.3",
"psutil>=7.0.0",
]
[build-system]

View File

@@ -34,7 +34,7 @@ async def runner_read_message() -> RunnerMessage:
line: bytes = await loop.run_in_executor(None, sys.stdin.buffer.readline)
if not line: # This seems to be what triggers when we don't clean up the runner neatly and leave the process dangling.
raise EOFError("No more data to read")
raise EOFError("No more data to read when reading runner message")
line = line.strip()
try:
@@ -66,7 +66,7 @@ async def supervisor_read_response(
line: str = line_bytes.decode("utf-8").strip()
if not line:
raise EOFError("No more data to read")
raise EOFError("No more data to read when reading response from runner")
try:
return RunnerResponseTypeAdapter.validate_json(line)

View File

@@ -10,7 +10,7 @@ import mlx.nn as nn
from mlx_lm.generate import stream_generate # type: ignore
from mlx_lm.tokenizer_utils import TokenizerWrapper
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx, mlx_force_oom
from shared.openai_compat import FinishReason
from shared.types.tasks import ChatCompletionTaskParams
from shared.types.worker.commands_runner import (
@@ -73,7 +73,7 @@ async def _mlx_generate(
chat_task_data=task,
)
max_tokens = task.max_tokens or 100
max_tokens = task.max_tokens or 1000
generation_fn = partial(_generate_tokens, prompt, max_tokens)
future = loop.run_in_executor(mlx_executor, generation_fn)
@@ -105,6 +105,12 @@ async def main():
setup_message = ensure_type(init_message, SetupMessage)
model_shard_meta = setup_message.model_shard_meta
hosts = setup_message.hosts
# For testing - these are fake break conditions
if model_shard_meta.immediate_exception:
raise Exception('Fake exception - runner failed to spin up.')
if model_shard_meta.should_timeout:
await asyncio.sleep(model_shard_meta.should_timeout)
setup_start_time = time.time()
@@ -127,7 +133,12 @@ async def main():
# TODO: this is a hack, why are we only looking at the first message? should have a tokenizer
prompt = task.messages[0]
if prompt.content is not None and 'EXO RUNNER MUST FAIL' in prompt.content:
runner_print('raising exception')
raise Exception('Artificial runner exception - for testing purposes only.')
if prompt.content is not None and 'EXO RUNNER MUST OOM' in prompt.content:
mlx_force_oom()
if prompt.content is not None and 'EXO RUNNER MUST TIMEOUT' in prompt.content:
await asyncio.sleep(100)
# Generate responses using the actual MLX generation
async for generation_response in _mlx_generate(

View File

@@ -1,10 +1,12 @@
import asyncio
import contextlib
import sys
import traceback
from collections.abc import AsyncGenerator
from logging import Logger
from types import CoroutineType
from typing import Any, Callable
from typing import Any, Callable, Optional
import psutil
from shared.types.common import CommandId, Host
from shared.types.events.chunks import GenerationChunk, TokenChunk
@@ -12,7 +14,6 @@ from shared.types.tasks import ChatCompletionTaskParams, Task
from shared.types.worker.commands_runner import (
ChatTaskMessage,
ErrorResponse,
ExitMessage,
FinishedResponse,
GenerationResponse,
InitializedResponse,
@@ -20,12 +21,19 @@ from shared.types.worker.commands_runner import (
RunnerResponse,
SetupMessage,
)
from shared.types.worker.common import RunnerError
from shared.types.worker.shards import ShardMetadata
from worker.runner.communication import (
supervisor_read_response,
supervisor_write_message,
)
from worker.runner.utils import get_runner_command
from worker.runner.utils import (
get_init_timeout,
get_prefil_timeout,
get_runner_command,
get_token_generate_timeout,
get_weights_size_kb,
)
class RunnerSupervisor:
@@ -33,47 +41,52 @@ class RunnerSupervisor:
RunnerSupervisor manages the lifecycle of a runner subprocess for model inference.
Use the class method `create` to properly initialize an instance.
"""
# TODO: Logger.
def __init__(
    self,
    model_shard_meta: ShardMetadata,
    hosts: list[Host],
    runner_process: asyncio.subprocess.Process,
    logger: Logger,
):
    """Private constructor. Use RunnerSupervisor.create() instead.

    Stores the shard metadata, hosts, subprocess handle and logger, resets
    the crash-tracking state, and schedules the two background watcher
    tasks. ``asyncio.create_task`` needs a running event loop, so this must
    be called from inside a coroutine.

    Args:
        model_shard_meta: Metadata of the shard this runner serves.
        hosts: Hosts passed to the runner at setup.
        runner_process: The already-spawned runner subprocess.
        logger: Logger used for stderr lines and crash reports.
    """
    self.model_shard_meta: ShardMetadata = model_shard_meta
    self.hosts: list[Host] = hosts
    self.runner_process: asyncio.subprocess.Process = runner_process
    self.logger = logger
    self.running: bool = True

    # Crash-tracking state. It is initialised before the watcher tasks are
    # scheduled so they never observe a half-built supervisor.
    self.stderr_buffer: list[str] = []  # Accumulate stderr lines
    self.crash_detected: bool = False
    self.returncode: int | None = None
    # Previously misspelled ``stderr_outpu``, which left ``stderr_output``
    # undefined until the stderr watcher happened to assign it.
    self.stderr_output: str | None = None

    self.stderr_task = asyncio.create_task(self._watch_stderr(logger))
    self.running_task: asyncio.Task[None] = asyncio.create_task(
        self._watch_runner()
    )
@classmethod
async def create(
cls,
model_shard_meta: ShardMetadata,
hosts: list[Host],
logger: Logger
logger: Logger,
initialize_timeout: Optional[float] = None,
) -> "RunnerSupervisor":
"""
Create and initialize a RunnerSupervisor instance.
The .create() classmethod pattern is used to ensure the constructor is asynchronous.
"""
cmd: list[str] = get_runner_command()
runner_process: asyncio.subprocess.Process = (
await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=sys.stderr
stderr=asyncio.subprocess.PIPE,
)
)
print(f'{model_shard_meta=}')
logger.info(f'initializing mlx instance with {model_shard_meta=}')
await supervisor_write_message(
runner_process,
SetupMessage(
@@ -82,88 +95,159 @@ class RunnerSupervisor:
),
)
while True:
line: RunnerResponse | None = await supervisor_read_response(
runner_process
)
if line is None or isinstance(line, PrintResponse):
# print(line)
continue
elif isinstance(line, ErrorResponse):
raise Exception(line.error_type, line.error_message, line.traceback or "")
else:
assert isinstance(line, InitializedResponse)
logger.info(f'Runner initialized in {line.time_taken} seconds')
print(f'Runner initialized in {line.time_taken} seconds')
break
async def read_initialization_message() -> None:
while True:
line: RunnerResponse | None = await supervisor_read_response(
runner_process
)
if line is None:
continue
elif isinstance(line, PrintResponse):
logger.info(line)
continue
elif isinstance(line, ErrorResponse):
raise RunnerError(line.error_type, line.error_message, line.traceback or "")
elif isinstance(line, InitializedResponse):
assert isinstance(line, InitializedResponse)
logger.info(f'Runner initialized in {line.time_taken} seconds')
break
else:
raise AssertionError(f'Non-valid line read from runner during initialization: {line}')
if not initialize_timeout:
initialize_timeout = get_init_timeout(model_shard_meta)
await asyncio.wait_for(read_initialization_message(), timeout=initialize_timeout)
return cls(
model_shard_meta=model_shard_meta,
hosts=hosts,
runner_process=runner_process,
logger=logger,
)
async def astop(self) -> None:
async def terminate() -> None:
# Check if process is already dead before trying to terminate
if self.runner_process.returncode is None:
self.runner_process.terminate()
# Wait for the process to exit (or confirm it's already exited)
try:
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=1.0)
except asyncio.TimeoutError:
# If terminate didn't work, force kill
if self.runner_process.returncode is None:
self.runner_process.kill()
_ = await self.runner_process.wait()
# Cancel the stderr monitoring task
if not self.stderr_task.done():
self.stderr_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.stderr_task
if not self.healthy:
print("Runner process is not healthy, killing...")
await terminate()
print('terminated')
if self.runner_process.stdout is not None:
# Kill the process and all its children
await self._kill_process_tree()
# Wait to make sure that the model has been unloaded from memory
async def wait_for_memory_release() -> None:
required_memory_bytes = get_weights_size_kb(self.model_shard_meta) * 1024
start_time = asyncio.get_event_loop().time()
while True:
try:
line = await asyncio.wait_for(
self.runner_process.stdout.readline(), timeout=0.01
)
if not line:
break
print(f"Remaining stdout: {line.decode('utf-8').strip()}")
except asyncio.TimeoutError:
available_memory_bytes = psutil.virtual_memory().available
if available_memory_bytes >= required_memory_bytes:
break
if asyncio.get_event_loop().time() - start_time > 30.0:
self.logger.warning("Timeout waiting for memory release after 30 seconds")
break
await asyncio.sleep(0.1)
# Only try to send ExitMessage if process is still alive
if self.runner_process.returncode is None:
try:
# Give the process a moment to exit gracefully
await supervisor_write_message(
proc=self.runner_process, message=ExitMessage()
)
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
except asyncio.TimeoutError:
print("Runner process did not terminate, killing...")
await terminate()
except Exception:
# If we can't write to the process (e.g., broken pipe), it's probably already dead
pass
await wait_for_memory_release()
self.running = False
async def _kill_process_tree(self) -> None:
    """Forcefully kill the runner process together with all of its descendants.

    Does nothing when the subprocess has already exited. Any unexpected
    error is logged rather than propagated.
    """
    process = self.runner_process
    if process.returncode is not None:
        # Already reaped: nothing left to kill.
        return

    try:
        pid = process.pid

        try:
            root = psutil.Process(pid)
            descendants = root.children(recursive=True)

            # Walk the descendant list back-to-front, then finish with the root.
            for descendant in reversed(descendants):
                with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
                    descendant.kill()  # SIGKILL

            with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
                root.kill()  # SIGKILL
        except psutil.NoSuchProcess:
            # psutil could not find the pid; still ask asyncio to kill it.
            process.kill()

        # Give the subprocess a bounded amount of time to exit.
        try:
            await asyncio.wait_for(process.wait(), timeout=2.0)
        except asyncio.TimeoutError:
            self.logger.error(f"Process {pid} did not exit after kill signal")
    except Exception as e:
        self.logger.error(f"Error killing process tree: {e}")
async def _watch_runner(self) -> None:
    """Wait for the runner subprocess to exit and record how it ended.

    Clears ``self.running`` once the process has exited. A non-zero exit
    status additionally sets ``crash_detected`` and stores the status in
    ``self.returncode``.
    """
    # Await the exit exactly once and keep the status; a second discarded
    # wait() beforehand added nothing.
    returncode = await self.runner_process.wait()
    self.running = False

    if returncode != 0:
        self.crash_detected = True
        self.returncode = returncode  # Will be picked up by _watch_stderr too
async def _watch_stderr(self, logger: Logger) -> None:
    """Drain the runner's stderr, logging each line and recording crash state.

    Every line read is appended to ``self.stderr_buffer`` and logged. Any
    stderr output currently marks the runner as crashed. After EOF, or after
    a read error, a non-zero return code is also recorded as a crash.

    Args:
        logger: Logger that receives the stderr lines and crash reports.
    """
    assert self.runner_process.stderr is not None
    while self.running:
        try:
            line_bytes = await self.runner_process.stderr.readline()
            if not line_bytes:
                break  # EOF
            # errors="replace": crash output is not guaranteed to be valid
            # UTF-8, and a strict decode would raise, hit the generic handler
            # below and stop stderr monitoring altogether.
            line = line_bytes.decode('utf-8', errors='replace').strip()

            self.stderr_buffer.append(line)
            logger.error(f"Runner stderr: {line}")
            # Detect common crash patterns (extend as needed, e.g., for OOM: "Killed" or "Out of memory")
            # NOTE(review): no pattern is checked yet, so any stderr line flags a crash.
            self.crash_detected = True
            self.stderr_output = "\n".join(self.stderr_buffer)
            logger.critical(f"Runner crash detected: {self.stderr_output}")
            # Don't raise here—let callers (e.g., stream_response) detect via healthy/returncode
        except Exception as e:
            logger.error(f"Error reading runner stderr: {e}")
            break

    # After EOF, inspect returncode for confirmation (Unix-like: negative == signal)
    returncode = self.runner_process.returncode
    if returncode is not None and returncode != 0:
        self.crash_detected = True
        self.returncode = returncode
        self.stderr_output = "\n".join(self.stderr_buffer)
def _raise_if_crashed(self) -> None:
    """Raise ``RunnerError`` if a watcher task has flagged the runner as crashed.

    Returns normally when ``crash_detected`` is false.

    Raises:
        RunnerError: with ``error_type="RunnerCrash"`` and the captured
            stderr text (empty when none was captured) as the message.
    """
    if not self.crash_detected:
        return

    # The exit watcher can flag a crash before the stderr watcher has stored
    # any output, so tolerate a missing or None attribute here. Otherwise the
    # crash report would be replaced by an AttributeError or a None message.
    stderr_output = getattr(self, "stderr_output", None) or ""

    self.logger.error(f'Error {self.returncode}: {stderr_output}')
    raise RunnerError(
        error_type="RunnerCrash",
        error_message=stderr_output,
        # format_exc() describes the exception being handled by the caller, if any.
        traceback=traceback.format_exc(),
    )
def __del__(self) -> None:
    """Last-resort cleanup: kill the runner process tree if still marked running.

    Runs synchronously at garbage collection. Children are enumerated and
    killed before the parent. If that fails for any reason, it falls back to
    killing only the subprocess itself.
    """
    # __del__ also runs for instances whose __init__ never completed, so the
    # attribute may not exist yet.
    if not getattr(self, "running", False):
        return

    print(
        "Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process tree."
    )

    # Can't use async in __del__, so use psutil directly.
    # The parent must not be killed before its children are enumerated,
    # otherwise the lookup below can fail and descendants are left running.
    try:
        pid = self.runner_process.pid
        if pid:
            parent = psutil.Process(pid)
            children = parent.children(recursive=True)
            for child in reversed(children):
                with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
                    child.kill()
            with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
                parent.kill()
    except Exception:
        with contextlib.suppress(ProcessLookupError):
            self.runner_process.kill()
@property
def healthy(self) -> bool:
@@ -178,7 +262,7 @@ class RunnerSupervisor:
async def stream_response(
self,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, # fyi this is async now
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, # fyi this is async now
) -> AsyncGenerator[GenerationChunk]:
"""
Streams a chat request from the model.
@@ -187,50 +271,52 @@ class RunnerSupervisor:
"""
if not self.healthy:
raise RuntimeError("Runner process was found to be dead")
task_params = task.task_params
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
await supervisor_write_message(
proc=self.runner_process,
message=ChatTaskMessage(
task_data=task_params,
),
)
# This is easy for now. If we need more reliability, the runner can have a new 'ready' message type.
if request_started_callback is not None:
await request_started_callback()
prefil_timeout = get_prefil_timeout(self.model_shard_meta)
token_timeout = get_token_generate_timeout(self.model_shard_meta)
timeout = prefil_timeout
while True:
line: RunnerResponse | None = await supervisor_read_response(
self.runner_process
)
if line is None:
continue
else:
match line:
case GenerationResponse(
text=text, token=token, finish_reason=finish_reason
):
yield TokenChunk(
command_id=CommandId(task.command_id),
idx=token,
model=self.model_shard_meta.model_meta.model_id,
text=text,
token_id=token,
finish_reason=finish_reason,
)
case InitializedResponse():
raise ValueError('Initialized Response read during streaming flow')
case FinishedResponse():
break
case PrintResponse(text=text):
print(f"runner printed: {text}")
case ErrorResponse(
error_type=error_type,
error_message=error_message,
traceback=traceback,
):
await self.astop()
raise Exception(error_type, error_message, traceback or "")
try:
line: RunnerResponse | None = await asyncio.wait_for(supervisor_read_response(
self.runner_process
), timeout=timeout)
if line is None:
continue
except (asyncio.TimeoutError, EOFError) as e:
self._raise_if_crashed()
raise RunnerError(
error_type=type(e).__name__,
error_message=str(e),
traceback="",
) from e
match line:
case GenerationResponse():
yield TokenChunk(
command_id=CommandId(task.command_id),
idx=line.token,
model=self.model_shard_meta.model_meta.model_id,
text=line.text,
token_id=line.token,
finish_reason=line.finish_reason,
)
timeout = token_timeout
case InitializedResponse():
raise ValueError('Initialized Response read during streaming flow')
case FinishedResponse():
break
case PrintResponse():
# print(f"runner printed: {line.text}")
self.logger.info(f"runner printed: {line.text}")
case ErrorResponse():
await self.astop()
raise RunnerError(line.error_type, line.error_message, line.traceback or "")

View File

@@ -1,6 +1,34 @@
import sys
from shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS, LB_TFLOPS
from shared.types.worker.shards import ShardMetadata
def get_runner_command() -> list[str]:
    """Return the argv that launches the runner module with the current interpreter."""
    return [sys.executable, "-m", "worker.runner.runner"]
def get_weights_size_kb(model_shard_meta: ShardMetadata) -> float:
    """Return the model's storage size scaled by the fraction of layers in this shard.

    The fraction is ``(end_layer - start_layer) / n_layers`` and the result
    is in the same unit as ``model_meta.storage_size_kilobytes``.
    """
    layers_in_shard = model_shard_meta.end_layer - model_shard_meta.start_layer
    shard_fraction = layers_in_shard / model_shard_meta.n_layers
    return shard_fraction * model_shard_meta.model_meta.storage_size_kilobytes
def get_init_timeout(model_shard_meta: ShardMetadata) -> float:
    """Return the seconds allowed for runner initialisation.

    Divides the shard's weight size by a third of ``LB_DISK_GBPS`` (converted
    with 1024 * 1024) and adds a fixed 2 second allowance.
    """
    shard_kb = get_weights_size_kb(model_shard_meta)
    # A third of the disk figure leaves 3x headroom on the estimate.
    kb_per_second = 1024 * 1024 * LB_DISK_GBPS / 3
    load_seconds = shard_kb / kb_per_second
    return load_seconds + 2.0
def get_prefil_timeout(model_shard_meta: ShardMetadata) -> float:
    """Return the seconds allowed for the first token (prompt prefill).

    The estimate is work divided by throughput: the prompt's GFLOPs over
    ``LB_TFLOPS`` converted to GFLOPS, times a 3x safety factor, plus a fixed
    10 second allowance. The timeout therefore grows with shard size.

    NOTE(review): presumes ``LB_TFLOPS`` is a positive throughput figure in
    TFLOPS; confirm against ``shared.constants``.
    """
    weights_size_gb = get_weights_size_kb(model_shard_meta) / (1024 * 1024)

    tokens = 1000  # constant for now - the prompt is only tokenized in the device...
    prompt_gflops = tokens * weights_size_gb * 2

    # The ratio used to be inverted (LB_TFLOPS / (1024 * prompt_gflops)), which
    # shrank the timeout as the shard grew and divided by zero for an empty shard.
    return prompt_gflops / (LB_TFLOPS * 1024) * 3 + 10.0
def get_token_generate_timeout(model_shard_meta: ShardMetadata) -> float:
    """Return the seconds allowed between consecutive generated tokens.

    Divides the shard's weight size by a third of ``LB_MEMBW_GBPS`` (converted
    with 1024 * 1024) and adds a fixed 2 second allowance.
    """
    shard_kb = get_weights_size_kb(model_shard_meta)
    # A third of the memory-bandwidth figure leaves 3x headroom on the estimate.
    kb_per_second = 1024 * 1024 * LB_MEMBW_GBPS / 3
    read_seconds = shard_kb / kb_per_second
    return read_seconds + 2.0

View File

@@ -112,7 +112,6 @@ def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], h
@pytest.fixture
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
"""Creates ChatCompletionParams with the given message"""
return ChatCompletionTaskParams(
model="gpt-4",
messages=[ChatCompletionMessage(role="user", content=user_message)],
@@ -121,19 +120,19 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
@pytest.fixture
def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
def _chat_completion_task(instance_id: Optional[InstanceId] = None, task_id: Optional[TaskId] = None) -> ChatCompletionTask:
if instance_id is None:
instance_id = INSTANCE_1_ID
if task_id is None:
task_id = TASK_1_ID
def _chat_completion_task(
instance_id: Optional[InstanceId] = None,
task_id: Optional[TaskId] = None,
user_message: str = "Hello"
) -> ChatCompletionTask:
resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
resolved_task_id = task_id if task_id is not None else TASK_1_ID
return ChatCompletionTask(
task_id=task_id,
task_id=resolved_task_id,
command_id=COMMAND_1_ID,
instance_id=instance_id,
instance_id=resolved_instance_id,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=completion_create_params
)
return _chat_completion_task

View File

View File

@@ -1,28 +1,46 @@
## Tests for worker state handlers
import asyncio
from typing import Callable
import pytest
from shared.types.events import (
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.tasks import ChatCompletionTask, TaskStatus
from shared.types.tasks import ChatCompletionTask
from shared.types.worker.common import RunnerError
from shared.types.worker.instances import Instance
from shared.types.worker.ops import (
ExecuteTaskOp,
)
from shared.types.worker.runners import (
FailedRunnerStatus,
RunningRunnerStatus,
RunnerUpOp,
)
from worker.main import Worker
from worker.tests.constants import RUNNER_1_ID
from worker.tests.test_handlers.utils import read_events_op
@pytest.mark.asyncio
async def test_runner_up_fails(
    worker_with_assigned_runner: tuple[Worker, Instance],
    chat_completion_task: Callable[[], ChatCompletionTask],
):
    """A runner told to raise during start-up makes RunnerUpOp fail with RunnerError."""
    worker = worker_with_assigned_runner[0]

    # Test-only flag on the shard metadata that makes the runner raise on start-up.
    shard_metadata = worker.assigned_runners[RUNNER_1_ID].shard_metadata
    shard_metadata.immediate_exception = True

    op = RunnerUpOp(runner_id=RUNNER_1_ID)

    with pytest.raises(RunnerError):
        await read_events_op(worker, op)
@pytest.mark.asyncio
async def test_runner_up_timeouts(
    worker_with_assigned_runner: tuple[Worker, Instance],
    chat_completion_task: Callable[[], ChatCompletionTask],
):
    """A runner told to stall during start-up makes RunnerUpOp raise asyncio.TimeoutError."""
    worker = worker_with_assigned_runner[0]

    # Test-only setting: the runner sleeps this many seconds before initialising.
    shard_metadata = worker.assigned_runners[RUNNER_1_ID].shard_metadata
    shard_metadata.should_timeout = 10

    op = RunnerUpOp(runner_id=RUNNER_1_ID)

    with pytest.raises(asyncio.TimeoutError):
        await read_events_op(worker, op)
@pytest.mark.asyncio
async def test_execute_task_fails(
worker_with_running_runner: tuple[Worker, Instance],
@@ -38,24 +56,27 @@ async def test_execute_task_fails(
task=task
)
events = await read_events_op(worker, execute_task_op)
with pytest.raises(RunnerError):
await read_events_op(worker, execute_task_op)
assert len(events) == 5
@pytest.mark.asyncio
async def test_execute_task_timeouts(
worker_with_running_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask]):
worker, _ = worker_with_running_runner
print(events)
task = chat_completion_task()
messages = task.task_params.messages
messages[0].content = 'Artificial prompt: EXO RUNNER MUST TIMEOUT'
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, RunningRunnerStatus) # It tried to start.
execute_task_op = ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=task
)
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
with pytest.raises(RunnerError): # At the moment this is a RunnerError that says 'TimeoutError'.
await read_events_op(worker, execute_task_op)
assert isinstance(events[2], TaskStateUpdated)
assert events[2].task_status == TaskStatus.FAILED # Task marked as failed.
assert isinstance(events[3], TaskFailed)
assert isinstance(events[4], RunnerStatusUpdated)
assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed.
# TODO: Much more to do here!
# TODO: Much more to do here!
# runner assigned download stuff

View File

View File

@@ -29,7 +29,7 @@ def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker,
shard_downloader = NoopShardDownloader()
worker = Worker(node_id, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker))
asyncio.create_task(run(worker, logger))
return worker, global_events

View File

@@ -1,21 +1,36 @@
import asyncio
from typing import Tuple
from typing import Callable, Optional, Tuple, TypeVar
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.events import ChunkGenerated, TaskStateUpdated
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import TaskStatus
from shared.types.tasks import TaskId, TaskStatus
async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tuple[bool, bool, str]:
async def read_streaming_response(global_events: AsyncSQLiteEventStorage, filter_task: Optional[TaskId] = None) -> Tuple[bool, bool, str]:
# Read off all events - these should be our GenerationChunk events
seen_task_started, seen_task_finished = 0, 0
response_string = ''
finish_reason: str | None = None
idx = 0
if not filter_task:
idx = await global_events.get_last_idx()
else:
found = False
idx = 0
while not found:
events = await global_events.get_events_since(idx)
for event in events:
if isinstance(event.event, TaskStateUpdated) and event.event.task_status == TaskStatus.RUNNING and event.event.task_id == filter_task:
found = True
idx = event.idx_in_log - 1
break
print(f'START IDX {idx}')
while not finish_reason:
events = await global_events.get_events_since(idx)
if len(events) == 0:
@@ -41,4 +56,26 @@ async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tup
print(f'event log: {await global_events.get_events_since(0)}')
return seen_task_started == 1, seen_task_finished == 1, response_string
return seen_task_started == 1, seen_task_finished == 1, response_string
T = TypeVar("T")
async def until_event_with_timeout(
    global_events: AsyncSQLiteEventStorage,
    event_type: type[T],
    multiplicity: int = 1,
    condition: Callable[[T], bool] = lambda x: True,
    timeout: float | None = 30.0,
) -> None:
    """Poll the event log until enough matching events have been appended.

    Only events appended after this call starts are considered: polling
    begins from the log's last index at call time.

    Args:
        global_events: Event store to poll.
        event_type: Class the wrapped event must be an instance of.
        multiplicity: Number of matching events required before returning.
        condition: Extra predicate a matching event must satisfy.
        timeout: Maximum seconds to wait; ``None`` waits indefinitely.

    Raises:
        asyncio.TimeoutError: if ``timeout`` elapses before ``multiplicity``
            matching events were seen.
    """
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout

    idx = await global_events.get_last_idx()
    times_seen = 0

    while True:
        events = await global_events.get_events_since(idx)
        if events:
            for wrapped_event in events:
                if isinstance(wrapped_event.event, event_type) and condition(wrapped_event.event):
                    times_seen += 1
                    if times_seen >= multiplicity:
                        return
            # Resume from the newest event seen so nothing is counted twice.
            idx = events[-1].idx_in_log

        # Without a deadline a missing event would hang the caller forever.
        if deadline is not None and loop.time() >= deadline:
            raise asyncio.TimeoutError(
                f"Timed out after {timeout}s waiting for {multiplicity} x "
                f"{event_type.__name__} (saw {times_seen})"
            )

        await asyncio.sleep(0.01)

View File

@@ -1,351 +0,0 @@
import asyncio
from logging import Logger
from typing import Awaitable, Callable
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerDeleted,
RunnerStatusUpdated,
TaskCreated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
ShardAssignments,
)
from shared.types.worker.runners import (
DownloadingRunnerStatus,
# RunningRunnerStatus,
FailedRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.common import AssignedRunner
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import run
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
TASK_2_ID,
)
from worker.tests.test_integration.integration_utils import (
read_streaming_response,
)
from worker.worker import Worker
async def test_runner_assigned(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
# Integration test: an INACTIVE instance placed on NODE_A is published via
# InstanceCreated; the worker on NODE_A should register RUNNER_1_ID and
# report it as inactive.
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.INACTIVE
await global_events.append_events(
[
InstanceCreated(
instance=instance_value
)
],
origin=MASTER_NODE_ID
)
# Fixed delay before asserting; presumably gives the worker loop time to
# process the event (timing-sensitive).
await asyncio.sleep(0.1)
# Ensure the worker has taken the correct action
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
# Ensure the correct events have been emitted
# Expected order: a Downloading status at index 1 and an Inactive status last.
events = await global_events.get_events_since(0)
assert len(events) >= 3 # len(events) is 4 if it's already downloaded. It is > 4 if there have to be download events.
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
assert isinstance(events[-1].event, RunnerStatusUpdated)
assert isinstance(events[-1].event.runner_status, InactiveRunnerStatus)
# Ensure state is correct
assert isinstance(worker.state.runners[RUNNER_1_ID], InactiveRunnerStatus)
async def test_runner_assigned_active(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
# Integration test: an ACTIVE instance placed on NODE_A should end with the
# runner in the Loaded state, and the resulting supervisor should be able
# to stream a response.
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
await global_events.append_events(
[
InstanceCreated(
instance=instance_value
)
],
origin=MASTER_NODE_ID
)
# Fixed delay before asserting; presumably long enough for the runner to
# load (timing-sensitive).
await asyncio.sleep(2.0)
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, LoadedRunnerStatus)
# Ensure the correct events have been emitted
# Expected order: Downloading at index 1, then Inactive, then Loaded as the last two.
events = await global_events.get_events_since(0)
assert len(events) >= 4 # len(events) is 5 if it's already downloaded. It is > 5 if there have to be download events.
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
assert isinstance(events[-2].event, RunnerStatusUpdated)
assert isinstance(events[-2].event.runner_status, InactiveRunnerStatus)
assert isinstance(events[-1].event, RunnerStatusUpdated)
assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
# Ensure state is correct
assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
# Ensure that the runner has been created and it can stream tokens.
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.healthy
# Concatenate the text of every TokenChunk the supervisor yields.
full_response = ''
async for chunk in supervisor.stream_response(task=chat_completion_task(INSTANCE_1_ID, TASK_1_ID)):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
# NOTE(review): the expected answer depends on the prompt supplied by the
# chat_completion_task fixture, which is not visible here.
assert "tokyo" in full_response.lower(), (
f"Expected 'Tokyo' in response, but got: {full_response}"
)
async def test_runner_assigned_wrong_node(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
# Integration test: the worker runs as NODE_A but the instance is placed on
# NODE_B, so the worker must neither assign a runner nor emit any event.
worker, global_events = await worker_running(NODE_A)
instance_value = instance(INSTANCE_1_ID, NODE_B, RUNNER_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance_value
)
],
origin=MASTER_NODE_ID
)
# Fixed delay before asserting (timing-sensitive).
await asyncio.sleep(0.1)
assert len(worker.assigned_runners) == 0
# Ensure the correct events have been emitted
# Only the InstanceCreated event appended above should be in the log.
events = await global_events.get_events_since(0)
assert len(events) == 1
# No RunnerStatusUpdated event should be emitted
# Ensure state is correct
assert len(worker.state.runners) == 0
async def test_runner_unassigns(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
# Integration test in two phases: create an ACTIVE instance and wait for
# the runner to load, then delete the instance and check the runner is
# removed from both the worker and its state.
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
await global_events.append_events(
[
InstanceCreated(
instance=instance_value
)
],
origin=MASTER_NODE_ID
)
# Fixed delay before asserting; presumably long enough for the runner to
# load (timing-sensitive).
await asyncio.sleep(2.0)
# already tested by test_runner_assigned_active
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, LoadedRunnerStatus)
# Ensure the correct events have been emitted (creation)
events = await global_events.get_events_since(0)
assert len(events) >= 4
assert isinstance(events[-1].event, RunnerStatusUpdated)
assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
# Ensure state is correct
assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
# Phase two: delete the instance.
await global_events.append_events(
[
InstanceDeleted(instance_id=instance_value.instance_id)
],
origin=MASTER_NODE_ID
)
# Fixed delay before asserting the teardown (timing-sensitive).
await asyncio.sleep(0.3)
assert len(worker.assigned_runners) == 0
# Ensure the correct events have been emitted (deletion)
events = await global_events.get_events_since(0)
assert isinstance(events[-1].event, RunnerDeleted)
# After deletion, runner should be removed from state.runners
assert len(worker.state.runners) == 0
async def test_runner_respawn(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
# Integration test with two workers sharing one event log:
# 1. run a task across a two-shard instance,
# 2. kill worker1's runner process,
# 3. expect Failed -> Inactive (both runners) -> Loaded (both runners),
# 4. run a second task to confirm the instance works again.
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
# Start from an empty log so event indices below are predictable.
await global_events.delete_all_events()
# Both workers use the same store for worker_events and global_events.
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker1))
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker2))
## Instance
# One runner per node, each given its own pipeline shard.
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
# Publish the instance and the first task together.
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await asyncio.sleep(0.1)
# Remember the log position so only post-kill events are inspected below.
idx = await global_events.get_last_idx()
assigned_runner: AssignedRunner = worker1.assigned_runners[RUNNER_1_ID]
assert assigned_runner.runner is not None
# Simulate a crash by killing worker1's runner subprocess directly.
assigned_runner.runner.runner_process.kill()
# Wait for the process to actually be detected as dead or cleaned up
for _ in range(100): # Wait up to 1 second
await asyncio.sleep(0.01)
# The worker may clean up the runner (set to None) when it detects it's dead
if assigned_runner.runner and not assigned_runner.runner.healthy:
break
else:
raise AssertionError("Runner should have been detected as unhealthy or cleaned up after kill()")
# Fixed delay; presumably long enough for both runners to be respawned
# (timing-sensitive).
await asyncio.sleep(5.0)
events = await global_events.get_events_since(idx)
# assert len(events) == 2
# Expected order after the kill: Failed, Inactive for RUNNER_2_ID,
# Inactive for RUNNER_1_ID, then two Loaded updates.
assert isinstance(events[0].event, RunnerStatusUpdated)
assert isinstance(events[0].event.runner_status, FailedRunnerStatus)
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, InactiveRunnerStatus)
assert events[1].event.runner_id == RUNNER_2_ID
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, InactiveRunnerStatus)
assert events[2].event.runner_id == RUNNER_1_ID
for event in [events[3].event, events[4].event]:
assert isinstance(event, RunnerStatusUpdated)
assert isinstance(event.runner_status, LoadedRunnerStatus)
# Run a second task to confirm the respawned runners can serve requests.
task = chat_completion_task(INSTANCE_1_ID, TASK_2_ID)
await global_events.append_events(
[
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await asyncio.sleep(0.1)
# Delete the instance at the end of the test.
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
# Fixed delay before returning; presumably lets the teardown finish.
await asyncio.sleep(1.0)

View File

@@ -62,6 +62,7 @@ async def test_runner_inference(
origin=MASTER_NODE_ID
)
# TODO: This needs to get fixed - sometimes it misses the 'starting' event.
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
@@ -93,10 +94,10 @@ async def test_2_runner_inference(
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker1))
asyncio.create_task(run(worker1, logger))
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker2))
asyncio.create_task(run(worker2, logger))
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
@@ -171,10 +172,10 @@ async def test_2_runner_multi_message(
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker1))
asyncio.create_task(run(worker1, logger))
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker2))
asyncio.create_task(run(worker2, logger))
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')

View File

@@ -17,6 +17,7 @@ from shared.types.events import (
TaskCreated,
TaskStateUpdated,
)
from shared.types.events._events import TaskFailed
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, RunnerId
@@ -34,6 +35,7 @@ from worker.tests.constants import (
RUNNER_1_ID,
TASK_1_ID,
)
from worker.tests.test_integration.integration_utils import until_event_with_timeout
@pytest.fixture
@@ -41,14 +43,13 @@ def user_message():
"""Override this fixture in tests to customize the message"""
return "Who is the longest ruling monarch of England?"
# TODO: Make this all monkeypatched instead.
async def test_stream_response_failed_always(
monkeypatch: MonkeyPatch,
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
) -> None:
_, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
@@ -74,7 +75,7 @@ async def test_stream_response_failed_always(
origin=MASTER_NODE_ID
)
await asyncio.sleep(5.)
await until_event_with_timeout(global_events, InstanceDeleted)
events = await global_events.get_events_since(0)
@@ -133,7 +134,7 @@ async def test_stream_response_failed_once(
origin=MASTER_NODE_ID
)
await asyncio.sleep(5.)
await until_event_with_timeout(global_events, ChunkGenerated, 1, condition=lambda x: isinstance(x.chunk, TokenChunk) and x.chunk.finish_reason is not None)
# TODO: Ideally this test would have tooling to scroll through the state history and say
# 'assert that there was a time when the error_type and error_message were not None and the failure count was nonzero'
@@ -179,65 +180,41 @@ async def test_stream_response_failed_once(
await asyncio.sleep(0.3)
# async def test_stream_response_timeout(
# monkeypatch: MonkeyPatch,
# worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
# instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
# chat_completion_task: Callable[[InstanceId, TaskId], Task]
# ):
# async def mock_stream_response(
# self: RunnerSupervisor,
# task: Task,
# request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
# ) -> AsyncGenerator[GenerationChunk]:
# # TODO: Also a test where we yield a few chunks and then time out.
# print('sleeping starting')
# await asyncio.sleep(4.)
# print('sleeping finished')
# return
# yield
async def test_stream_response_timeout(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
    """Check that a timing-out inference is recorded as three failures.

    An active instance and a task whose prompt is 'EXO RUNNER MUST TIMEOUT'
    are appended to the global event log. The test waits for three
    TaskFailed events and then checks the failure events in the log.
    """
    _, global_events = await worker_running(NODE_A)

    active_instance: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    active_instance.instance_type = InstanceStatus.ACTIVE

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    # NOTE(review): this prompt looks like a sentinel the runner reacts to
    # by timing out — confirm against the runner implementation.
    task.task_params.messages[0].content = 'EXO RUNNER MUST TIMEOUT'

    setup_events = [
        InstanceCreated(instance=active_instance),
        TaskCreated(task_id=task.task_id, task=task),
    ]
    await global_events.append_events(setup_events, origin=MASTER_NODE_ID)

    await until_event_with_timeout(global_events, TaskFailed, multiplicity=3)

    events = await global_events.get_events_since(0)
    print(events)

    failed_runner_updates = [
        entry for entry in events
        if isinstance(entry.event, RunnerStatusUpdated)
        and isinstance(entry.event.runner_status, FailedRunnerStatus)
    ]
    failed_task_updates = [
        entry for entry in events
        if isinstance(entry.event, TaskStateUpdated)
        and entry.event.task_status == TaskStatus.FAILED
    ]
    timeout_failures = [
        entry for entry in events
        if isinstance(entry.event, TaskFailed)
        and 'timeouterror' in entry.event.error_message.lower()
    ]
    assert len(failed_runner_updates) == 3
    assert len(failed_task_updates) == 3
    assert len(timeout_failures) == 3

    teardown_events = [InstanceDeleted(instance_id=active_instance.instance_id)]
    await global_events.append_events(teardown_events, origin=MASTER_NODE_ID)

    await asyncio.sleep(0.3)

View File

@@ -0,0 +1,85 @@
import asyncio
from typing import Awaitable, Callable
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.common import NodeId
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
)
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from shared.types.worker.runners import (
FailedRunnerStatus,
)
from worker.main import Worker
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
RUNNER_1_ID,
)
from worker.tests.test_integration.integration_utils import until_event_with_timeout
async def test_runner_spinup_exception(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
):
    """Check the events emitted when a runner's shard has immediate_exception set.

    After creating the instance and waiting five seconds, the log must hold
    exactly three failed runner status updates and at least one
    InstanceDeleted event.
    """
    _, global_events = await worker_running(NODE_A)

    failing_instance: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    failing_instance.instance_type = InstanceStatus.ACTIVE
    failing_instance.shard_assignments.runner_to_shard[RUNNER_1_ID].immediate_exception = True

    creation = InstanceCreated(instance=failing_instance)
    await global_events.append_events([creation], origin=MASTER_NODE_ID)

    await asyncio.sleep(5.0)

    # Ensure the correct events have been emitted
    recorded = await global_events.get_events_since(0)
    failed_updates = [
        entry for entry in recorded
        if isinstance(entry.event, RunnerStatusUpdated)
        and isinstance(entry.event.runner_status, FailedRunnerStatus)
    ]
    assert len(failed_updates) == 3
    assert any(isinstance(entry.event, InstanceDeleted) for entry in recorded)
async def test_runner_spinup_timeout(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
):
    """Check the events emitted when a runner's shard has should_timeout set.

    The test waits for three failed runner status updates, then checks the
    log holds exactly three of them and at least one InstanceDeleted event.
    """
    _, global_events = await worker_running(NODE_A)

    slow_instance: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    slow_instance.instance_type = InstanceStatus.ACTIVE
    slow_instance.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10

    creation = InstanceCreated(instance=slow_instance)
    await global_events.append_events([creation], origin=MASTER_NODE_ID)

    await until_event_with_timeout(
        global_events,
        RunnerStatusUpdated,
        multiplicity=3,
        condition=lambda update: isinstance(update.runner_status, FailedRunnerStatus),
    )

    # Ensure the correct events have been emitted
    recorded = await global_events.get_events_since(0)
    failed_updates = [
        entry for entry in recorded
        if isinstance(entry.event, RunnerStatusUpdated)
        and isinstance(entry.event.runner_status, FailedRunnerStatus)
    ]
    assert len(failed_updates) == 3
    assert any(isinstance(entry.event, InstanceDeleted) for entry in recorded)

View File

@@ -0,0 +1,85 @@
import asyncio
from typing import Awaitable, Callable
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.common import NodeId
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
)
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from shared.types.worker.runners import (
FailedRunnerStatus,
)
from worker.main import Worker
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
RUNNER_1_ID,
)
from worker.tests.test_integration.integration_utils import until_event_with_timeout
async def test_runner_spinup_exception(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
):
    """Check the events emitted when a runner's shard has immediate_exception set.

    After creating the instance and waiting five seconds, the log must hold
    exactly three failed runner status updates and at least one
    InstanceDeleted event.
    """
    _, global_events = await worker_running(NODE_A)

    failing_instance: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    failing_instance.instance_type = InstanceStatus.ACTIVE
    failing_instance.shard_assignments.runner_to_shard[RUNNER_1_ID].immediate_exception = True

    creation = InstanceCreated(instance=failing_instance)
    await global_events.append_events([creation], origin=MASTER_NODE_ID)

    await asyncio.sleep(5.0)

    # Ensure the correct events have been emitted
    recorded = await global_events.get_events_since(0)
    failed_updates = [
        entry for entry in recorded
        if isinstance(entry.event, RunnerStatusUpdated)
        and isinstance(entry.event.runner_status, FailedRunnerStatus)
    ]
    assert len(failed_updates) == 3
    assert any(isinstance(entry.event, InstanceDeleted) for entry in recorded)
async def test_runner_spinup_timeout(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
):
    """Check the events emitted when a runner's shard has should_timeout set.

    The test waits for three failed runner status updates, then checks the
    log holds exactly three of them and at least one InstanceDeleted event.
    """
    _, global_events = await worker_running(NODE_A)

    slow_instance: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    slow_instance.instance_type = InstanceStatus.ACTIVE
    slow_instance.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10

    creation = InstanceCreated(instance=slow_instance)
    await global_events.append_events([creation], origin=MASTER_NODE_ID)

    await until_event_with_timeout(
        global_events,
        RunnerStatusUpdated,
        multiplicity=3,
        condition=lambda update: isinstance(update.runner_status, FailedRunnerStatus),
    )

    # Ensure the correct events have been emitted
    recorded = await global_events.get_events_since(0)
    failed_updates = [
        entry for entry in recorded
        if isinstance(entry.event, RunnerStatusUpdated)
        and isinstance(entry.event.runner_status, FailedRunnerStatus)
    ]
    assert len(failed_updates) == 3
    assert any(isinstance(entry.event, InstanceDeleted) for entry in recorded)

View File

@@ -0,0 +1,258 @@
import asyncio
from logging import Logger
from typing import Callable
import pytest
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.models.model_meta import get_model_meta
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import Host
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
TaskCreated,
)
from shared.types.models import ModelId, ModelMetadata
from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import run
from worker.tests.constants import (
COMMAND_1_ID,
COMMAND_2_ID,
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
TASK_2_ID,
)
from worker.tests.test_integration.integration_utils import (
read_streaming_response,
)
from worker.worker import Worker
MODEL_ID = 'mlx-community/Llama-3.3-70B-Instruct-4bit'
@pytest.fixture
async def model_meta() -> ModelMetadata:
    """Fixture returning the metadata looked up for MODEL_ID."""
    metadata = await get_model_meta(MODEL_ID)
    return metadata
async def test_2_runner_inference(
    logger: Logger,
    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
    hosts: Callable[[int], list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
    """Run one chat completion on an instance split across two runners.

    Two workers (NODE_A and NODE_B) share a single event log. The test
    creates an active two-shard instance together with one task, reads the
    streamed response, and checks that it mentions 'swap'. It then checks
    that no further events appear for one second before deleting the
    instance.
    """
    event_log_manager = EventLogManager(EventLogConfig(), logger)
    await event_log_manager.initialize()
    shard_downloader = NoopShardDownloader()

    global_events = event_log_manager.global_events
    await global_events.delete_all_events()

    # Keep strong references to the worker loops: the event loop only holds
    # weak references to tasks, so a discarded task may be garbage-collected
    # before it finishes.
    worker_tasks: list[asyncio.Task[object]] = []

    worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    worker_tasks.append(asyncio.create_task(run(worker1, logger)))

    worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    worker_tasks.append(asyncio.create_task(run(worker2, logger)))

    ## Instance
    model_id = ModelId(MODEL_ID)

    shard_assignments = ShardAssignments(
        model_id=model_id,
        runner_to_shard={
            RUNNER_1_ID: pipeline_shard_meta(2, 0),
            RUNNER_2_ID: pipeline_shard_meta(2, 1)
        },
        node_to_runner={
            NODE_A: RUNNER_1_ID,
            NODE_B: RUNNER_2_ID
        }
    )

    instance = Instance(
        instance_id=INSTANCE_1_ID,
        instance_type=InstanceStatus.ACTIVE,
        shard_assignments=shard_assignments,
        hosts=hosts(2)
    )

    task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    task.task_params.messages[0].content = 'Can you explain to me how a bubble sort works, speaking as if you are a fairy.'
    task.task_params.max_tokens = 1000

    await global_events.append_events(
        [
            InstanceCreated(
                instance=instance
            ),
            TaskCreated(
                task_id=task.task_id,
                task=task
            )
        ],
        origin=MASTER_NODE_ID
    )

    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)

    assert seen_task_started
    assert seen_task_finished
    assert 'swap' in response_string.lower()

    # Once the response has finished, the log must stay quiet.
    idx = await global_events.get_last_idx()
    await asyncio.sleep(1.0)
    events = await global_events.get_events_since(idx)
    assert len(events) == 0

    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(2.0)
async def test_parallel_inference(
    logger: Logger,
    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
    hosts: Callable[[int], list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
    """Submit two chat completions at once to a two-runner instance.

    One task asks for a haiku containing "pond", the other for one
    containing "tree". Both responses are read back and each keyword must
    appear in exactly one of them. The test then checks that no further
    events appear for one second before deleting the instance.
    """
    event_log_manager = EventLogManager(EventLogConfig(), logger)
    await event_log_manager.initialize()
    shard_downloader = NoopShardDownloader()

    global_events = event_log_manager.global_events
    await global_events.delete_all_events()

    # Keep strong references to the worker loops: the event loop only holds
    # weak references to tasks, so a discarded task may be garbage-collected
    # before it finishes.
    worker_tasks: list[asyncio.Task[object]] = []

    worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    worker_tasks.append(asyncio.create_task(run(worker1, logger)))

    worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    worker_tasks.append(asyncio.create_task(run(worker2, logger)))

    ## Instance
    model_id = ModelId(MODEL_ID)

    shard_assignments = ShardAssignments(
        model_id=model_id,
        runner_to_shard={
            RUNNER_1_ID: pipeline_shard_meta(2, 0),
            RUNNER_2_ID: pipeline_shard_meta(2, 1)
        },
        node_to_runner={
            NODE_A: RUNNER_1_ID,
            NODE_B: RUNNER_2_ID
        }
    )

    instance = Instance(
        instance_id=INSTANCE_1_ID,
        instance_type=InstanceStatus.ACTIVE,
        shard_assignments=shard_assignments,
        hosts=hosts(2)
    )

    completion_create_params_1 = ChatCompletionTaskParams(
        model="gpt-4",
        messages=[ChatCompletionMessage(role="user", content='Tell me a haiku that uses the word "pond".')],
        stream=True,
        max_tokens=1000
    )
    task1 = ChatCompletionTask(
        task_id=TASK_1_ID,
        command_id=COMMAND_1_ID,
        instance_id=INSTANCE_1_ID,
        task_type=TaskType.CHAT_COMPLETION,
        task_status=TaskStatus.PENDING,
        task_params=completion_create_params_1
    )

    completion_create_params_2 = ChatCompletionTaskParams(
        model="gpt-4",
        messages=[ChatCompletionMessage(role="user", content='Tell me a haiku that uses the word "tree".')],
        stream=True,
        max_tokens=1000
    )
    task2 = ChatCompletionTask(
        task_id=TASK_2_ID,
        command_id=COMMAND_2_ID,
        instance_id=INSTANCE_1_ID,
        task_type=TaskType.CHAT_COMPLETION,
        task_status=TaskStatus.PENDING,
        task_params=completion_create_params_2
    )

    await global_events.append_events(
        [
            InstanceCreated(
                instance=instance
            ),
            TaskCreated(
                task_id=task1.task_id,
                task=task1
            ),
            TaskCreated(
                task_id=task2.task_id,
                task=task2
            ),
        ],
        origin=MASTER_NODE_ID
    )

    seen_task_started_1, seen_task_finished_1, response_string_1 = await read_streaming_response(global_events)

    # Read the other task next. Both branches previously yielded TASK_2_ID,
    # so the second read could return the same response as the first.
    # NOTE(review): this assumes worker1's state already reflects the task
    # read above as COMPLETE — confirm against read_streaming_response.
    first_task_done = worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.COMPLETE
    incomplete_task = TASK_2_ID if first_task_done else TASK_1_ID
    seen_task_started_2, seen_task_finished_2, response_string_2 = await read_streaming_response(global_events, filter_task=incomplete_task)

    assert seen_task_started_1
    assert seen_task_finished_1
    assert seen_task_started_2
    assert seen_task_finished_2

    print(response_string_1)
    print(response_string_2)

    assert (
        ('pond' in response_string_1.lower()) ^ ('pond' in response_string_2.lower())
    ), "'pond' must appear in exactly one response"
    assert (
        ('tree' in response_string_1.lower()) ^ ('tree' in response_string_2.lower())
    ), "'tree' must appear in exactly one response"

    # Once both responses have finished, the log must stay quiet.
    idx = await global_events.get_last_idx()
    await asyncio.sleep(1.0)
    events = await global_events.get_events_since(idx)
    assert len(events) == 0

    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(2.0)

View File

@@ -1,34 +1,29 @@
import asyncio
import os
from logging import Logger
from typing import Callable, Final
from typing import Callable
import pytest
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.common import Host
from shared.types.events import InstanceCreated, InstanceDeleted
from shared.types.models import ModelId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from shared.types.worker.runners import FailedRunnerStatus
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import run
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
)
from worker.worker import Worker
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final = "55555555-5555-4555-8555-555555555555"
TASK_2_ID: Final = "66666666-6666-4666-8666-666666666666"
@pytest.fixture
def user_message() -> str:
@@ -63,7 +58,7 @@ async def check_runner_connection(
global_events=global_events,
)
workers.append(worker1)
task1 = asyncio.create_task(run(worker1))
task1 = asyncio.create_task(run(worker1, logger))
tasks.append(task1)
worker2 = Worker(
@@ -74,7 +69,7 @@ async def check_runner_connection(
global_events=global_events,
)
workers.append(worker2)
task2 = asyncio.create_task(run(worker2))
task2 = asyncio.create_task(run(worker2, logger))
tasks.append(task2)
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
@@ -152,6 +147,16 @@ async def check_runner_connection(
# # not now.
# def test_runner_connection_stress(
# logger: Logger,
# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
# hosts: Callable[[int], list[Host]],
# chat_completion_task: Callable[[InstanceId, str], Task],
# ) -> None:
# total_runs = 100
# successes = 0
# # not now.
# def test_runner_connection_stress(
# logger: Logger,
# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
@@ -161,11 +166,29 @@ async def check_runner_connection(
# total_runs = 100
# successes = 0
# for _ in range(total_runs):
# # Create a fresh event loop for each iteration
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# for _ in range(total_runs):
# # Create a fresh event loop for each iteration
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(check_runner_connection(
# logger=logger,
# pipeline_shard_meta=pipeline_shard_meta,
# hosts=hosts,
# chat_completion_task=chat_completion_task,
# ))
# if result:
# successes += 1
# finally:
# # Cancel all running tasks
# pending = asyncio.all_tasks(loop)
# for task in pending:
# task.cancel()
# try:
# result = loop.run_until_complete(check_runner_connection(
# logger=logger,
@@ -181,10 +204,15 @@ async def check_runner_connection(
# for task in pending:
# task.cancel()
# # Run the event loop briefly to allow cancellation to complete
# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# # Run the event loop briefly to allow cancellation to complete
# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# # Close the event loop
# loop.close()
# # Close the event loop
# loop.close()
# print(f"Runner connection successes: {successes} / {total_runs}")
# print(f"Runner connection successes: {successes} / {total_runs}")

View File

@@ -0,0 +1,60 @@
from asyncio.subprocess import Process
from logging import Logger
from typing import Callable
import psutil
import pytest
from shared.models.model_meta import get_model_meta
from shared.types.common import Host
from shared.types.models import ModelMetadata
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerError
from shared.types.worker.shards import PipelineShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
def get_memory_mb(process: Process) -> float:
    """Return the resident set size (RSS) of ``process`` in MiB.

    The process is looked up through psutil by its pid.
    """
    bytes_per_mib = 1024 * 1024
    rss_bytes: int = psutil.Process(process.pid).memory_info().rss  # type: ignore[attr-defined]
    return rss_bytes / bytes_per_mib
@pytest.fixture
async def model_meta() -> ModelMetadata:
    """Fixture returning the metadata looked up for the Llama 3.3 70B 4-bit model."""
    metadata = await get_model_meta('mlx-community/Llama-3.3-70B-Instruct-4bit')
    return metadata
@pytest.mark.asyncio
async def test_supervisor_inference_exception(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
    logger: Logger,
):
    """Check that a failing inference raises RunnerError and memory is released.

    The runner process must use more than 3000 MiB of resident memory once
    the supervisor is created. After the failing request and astop(), the
    system must report more than 30 GiB of available memory.
    """
    model_shard_meta = pipeline_shard_meta(1, 0)

    supervisor = await RunnerSupervisor.create(
        model_shard_meta=model_shard_meta,
        hosts=hosts(1, offset=10),
        logger=logger,
    )

    try:
        process: Process = supervisor.runner_process
        memory = get_memory_mb(process)
        assert memory > 30*100

        task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
        # NOTE(review): this prompt looks like a sentinel the runner reacts
        # to by failing — confirm against the runner implementation.
        task.task_params.messages[0].content = 'EXO RUNNER MUST FAIL'
        with pytest.raises(RunnerError):
            async for _ in supervisor.stream_response(task):
                pass
    finally:
        # Stop the runner even when an assertion above fails, so the
        # process does not outlive the test.
        await supervisor.astop()

    available_memory_bytes: int = psutil.virtual_memory().available
    print(available_memory_bytes // (2**30))
    assert available_memory_bytes > 30 * 2**30

View File

@@ -0,0 +1,45 @@
from logging import Logger
from typing import Callable
import pytest
from shared.types.common import Host
from shared.types.tasks import (
Task,
TaskId,
)
from shared.types.worker.common import InstanceId, RunnerError
from shared.types.worker.shards import PipelineShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
@pytest.fixture
def user_message():
    """Replace the default prompt with a question about France's capital."""
    prompt = "What is the capital of France?"
    return prompt
@pytest.mark.asyncio
async def test_supervisor_single_node_response(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
    logger: Logger,
):
    """Check that the 'EXO RUNNER MUST OOM' prompt surfaces as RunnerError.

    A single-shard supervisor is created, the task prompt is replaced, and
    streaming the response must raise RunnerError.
    """
    model_shard_meta = pipeline_shard_meta(1, 0)

    supervisor = await RunnerSupervisor.create(
        model_shard_meta=model_shard_meta,
        hosts=hosts(1, offset=10),
        logger=logger,
    )

    try:
        task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
        # NOTE(review): this prompt looks like a sentinel the runner reacts
        # to by simulating an out-of-memory failure — confirm.
        task.task_params.messages[0].content = 'EXO RUNNER MUST OOM'
        with pytest.raises(RunnerError):
            async for _ in supervisor.stream_response(task):
                pass
    finally:
        # Stop the runner even when the expectation above fails, so the
        # process does not outlive the test.
        await supervisor.astop()

View File

@@ -0,0 +1,93 @@
import asyncio
from logging import Logger
from typing import Callable
import pytest
from shared.types.common import Host
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerError
from shared.types.worker.shards import PipelineShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
@pytest.mark.asyncio
async def test_supervisor_instantiation_exception(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    logger: Logger,
):
    """Check that creating a supervisor raises RunnerError when the shard
    metadata has immediate_exception set."""
    failing_shard = pipeline_shard_meta(1, 0)
    failing_shard.immediate_exception = True

    with pytest.raises(RunnerError):
        await RunnerSupervisor.create(
            model_shard_meta=failing_shard,
            hosts=hosts(1, offset=10),
            logger=logger,
        )
@pytest.mark.asyncio
async def test_supervisor_instantiation_timeout(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    logger: Logger,
):
    """Check that creating a supervisor raises asyncio.TimeoutError when the
    shard metadata has should_timeout set."""
    slow_shard = pipeline_shard_meta(1, 0)
    slow_shard.should_timeout = 10  # timeout after 10s

    with pytest.raises(asyncio.TimeoutError):
        await RunnerSupervisor.create(
            model_shard_meta=slow_shard,
            hosts=hosts(1, offset=10),
            logger=logger,
        )
@pytest.mark.asyncio
async def test_supervisor_inference_exception(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
    logger: Logger,
):
    """Check that the 'EXO RUNNER MUST FAIL' prompt surfaces as RunnerError.

    A single-shard supervisor is created, the task prompt is replaced, and
    streaming the response must raise RunnerError.
    """
    model_shard_meta = pipeline_shard_meta(1, 0)

    supervisor = await RunnerSupervisor.create(
        model_shard_meta=model_shard_meta,
        hosts=hosts(1, offset=10),
        logger=logger,
    )

    try:
        task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
        # NOTE(review): this prompt looks like a sentinel the runner reacts
        # to by failing — confirm against the runner implementation.
        task.task_params.messages[0].content = 'EXO RUNNER MUST FAIL'
        with pytest.raises(RunnerError):
            async for _ in supervisor.stream_response(task):
                pass
    finally:
        # The supervisor was previously never stopped here; stop it so the
        # runner process does not outlive the test.
        await supervisor.astop()
@pytest.mark.asyncio
async def test_supervisor_inference_timeout(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
    logger: Logger,
):
    """Check that the 'EXO RUNNER MUST TIMEOUT' prompt surfaces as RunnerError.

    A single-shard supervisor is created, the task prompt is replaced, and
    streaming the response must raise RunnerError.
    """
    model_shard_meta = pipeline_shard_meta(1, 0)

    supervisor = await RunnerSupervisor.create(
        model_shard_meta=model_shard_meta,
        hosts=hosts(1, offset=10),
        logger=logger,
    )

    try:
        task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
        # NOTE(review): this prompt looks like a sentinel the runner reacts
        # to by hanging until the supervisor times out — confirm.
        task.task_params.messages[0].content = 'EXO RUNNER MUST TIMEOUT'
        with pytest.raises(RunnerError):
            async for _ in supervisor.stream_response(task):
                pass
    finally:
        # The supervisor was previously never stopped here; stop it so the
        # runner process does not outlive the test.
        # NOTE(review): assumes astop() is safe after a timed-out request.
        await supervisor.astop()

View File

@@ -1,4 +1,5 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine
@@ -66,6 +67,11 @@ async def start_polling_node_metrics(
# Run heavy FLOPs profiling only if enough time has elapsed
override_memory_env = os.getenv('OVERRIDE_MEMORY')
override_memory: int | None = (
int(override_memory_env) * 2**30 if override_memory_env else None
)
await callback(
NodePerformanceProfile(
model_id=system_info.model_id,
@@ -74,7 +80,7 @@ async def start_polling_node_metrics(
network_interfaces=network_interfaces,
memory=MemoryPerformanceProfile(
ram_total=total_mem,
ram_available=total_mem - used_mem,
ram_available=override_memory if override_memory else total_mem - used_mem,
swap_total=metrics.memory.swap_total
if metrics.memory is not None
and metrics.memory.swap_total is not None

View File

@@ -3,7 +3,6 @@ import logging
import time
from asyncio import Queue
from functools import partial
from time import process_time
from typing import AsyncGenerator, Optional
from shared.db.sqlite import AsyncSQLiteEventStorage
@@ -22,7 +21,6 @@ from shared.types.tasks import TaskId, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgressData,
@@ -71,104 +69,116 @@ class Worker:
## Op Executors
async def _execute_assign_op(
self, op: AssignRunnerOp
) -> AsyncGenerator[Event, None]:
'''
A runner has been assigned. We need to also ensure that it's downloaded.
This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state.
'''
self.assigned_runners[op.runner_id] = AssignedRunner(
def _create_assigned_runner(self, op: AssignRunnerOp) -> AssignedRunner:
    """Build an AssignedRunner for ``op`` and register it under its runner id.

    The new runner starts in a downloading status whose progress is
    DownloadPending for this node, with no runner attached yet.
    """
    pending_status = DownloadingRunnerStatus(
        download_progress=DownloadPending(node_id=self.node_id)
    )
    new_runner = AssignedRunner(
        runner_id=op.runner_id,
        instance_id=op.instance_id,
        shard_metadata=op.shard_metadata,
        hosts=op.hosts,
        status=pending_status,
        runner=None,
    )
    self.assigned_runners[op.runner_id] = new_runner
    return new_runner
assigned_runner = self.assigned_runners[op.runner_id]
async def _update_runner_status_to_completed_then_inactive(
    self, assigned_runner: AssignedRunner
) -> AsyncGenerator[Event, None]:
    """Mark the download as completed, then mark the runner inactive.

    Yields one status-update event after each of the two transitions.
    """
    completed = DownloadCompleted(node_id=self.node_id)
    assigned_runner.status = DownloadingRunnerStatus(download_progress=completed)
    yield assigned_runner.status_update_event()

    assigned_runner.status = InactiveRunnerStatus()
    yield assigned_runner.status_update_event()
async def _handle_already_downloaded_shard(
    self, assigned_runner: AssignedRunner
) -> AsyncGenerator[Event, None]:
    """Emit the completed-then-inactive transition for an already-downloaded shard."""
    transition = self._update_runner_status_to_completed_then_inactive(assigned_runner)
    async for status_event in transition:
        yield status_event
async def _handle_shard_download_process(
    self, assigned_runner: AssignedRunner, op: AssignRunnerOp, initial_progress: RepoDownloadProgress
) -> AsyncGenerator[Event, None]:
    """Start the shard download and relay its progress as status events.

    First yields an ongoing-download status built from ``initial_progress``.
    It then registers a progress callback, starts ``ensure_shard`` as a
    task, and forwards every event from ``_monitor_download_progress``.
    The download task is cancelled on exit if it has not finished.
    """
    # Report the starting point before any new progress arrives.
    starting_point = DownloadProgressData(
        total_bytes=initial_progress.total_bytes,
        downloaded_bytes=initial_progress.downloaded_bytes
    )
    assigned_runner.status = DownloadingRunnerStatus(
        download_progress=DownloadOngoing(
            node_id=self.node_id,
            download_progress=starting_point
        )
    )
    yield assigned_runner.status_update_event()

    # Progress reports are handed over from the callback through a queue.
    progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue()

    def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None:
        progress_queue.put_nowait(progress)

    self.shard_downloader.on_progress(download_progress_callback)
    ensure_task = asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata))

    try:
        async for status_event in self._monitor_download_progress(assigned_runner, progress_queue):
            yield status_event
    finally:
        # Do not leave the download running once monitoring has stopped.
        if not ensure_task.done():
            ensure_task.cancel()
async def _monitor_download_progress(
    self, assigned_runner: AssignedRunner, download_progress_queue: asyncio.Queue[RepoDownloadProgress]
) -> AsyncGenerator[Event, None]:
    """Turn queued download progress reports into runner status events.

    Reads reports from ``download_progress_queue`` until one with status
    ``"complete"`` arrives, at which point the completed/inactive status
    events are yielded and the generator ends. Reports with status
    ``"in_progress"`` update the runner to ``DownloadOngoing`` and yield a
    status event, but only if more than one second (by ``time.monotonic``)
    has passed since the previous such event. Reports with any other status
    are ignored.

    Each wait on the queue is bounded to 15 seconds via ``asyncio.wait_for``;
    the resulting timeout error is not caught here and propagates to the
    caller.
    """
    min_gap_between_updates_secs = 1.0
    last_emitted_at = 0.0

    while True:
        update: RepoDownloadProgress = await asyncio.wait_for(download_progress_queue.get(), timeout=15)

        if update.status == "complete":
            async for status_event in self._update_runner_status_to_completed_then_inactive(assigned_runner):
                yield status_event
            return

        # Anything that is neither "complete" nor "in_progress" is skipped.
        if update.status != "in_progress":
            continue

        # Throttle: drop intermediate reports that arrive too soon after the
        # last one we forwarded.
        if time.monotonic() - last_emitted_at <= min_gap_between_updates_secs:
            continue

        byte_counts = DownloadProgressData(
            total_bytes=update.total_bytes,
            downloaded_bytes=update.downloaded_bytes,
        )
        assigned_runner.status = DownloadingRunnerStatus(
            download_progress=DownloadOngoing(
                node_id=self.node_id,
                download_progress=byte_counts,
            )
        )
        yield assigned_runner.status_update_event()
        last_emitted_at = time.monotonic()
async def _execute_assign_op(
    self, op: AssignRunnerOp
) -> AsyncGenerator[Event, None]:
    """
    A runner has been assigned. We need to also ensure that it's downloaded.
    This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state.

    The runner is registered via ``_create_assigned_runner`` and the shard's
    current download status is queried from ``self.shard_downloader``. If the
    shard is already complete, the completed/inactive events are emitted
    straight away; otherwise the download is started and its progress events
    are streamed.

    All status handling lives in the helper generators; this method only
    chooses between them, so the download logic exists in exactly one place
    and is started at most once per op. Exceptions raised by the helpers are
    not caught here and propagate to the caller.
    """
    assigned_runner = self._create_assigned_runner(op)
    initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)

    if initial_progress.status == "complete":
        async for event in self._handle_already_downloaded_shard(assigned_runner):
            yield event
    else:
        async for event in self._handle_shard_download_process(assigned_runner, op, initial_progress):
            yield event
async def _execute_unassign_op(
self, op: UnassignRunnerOp
@@ -193,39 +203,32 @@ class Worker:
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
# TODO: This should be dynamic, based on the size of the model.
if not initialize_timeout:
gigabytes_per_second = 10
kilobytes_per_second = gigabytes_per_second * 1024 * 1024
shard = assigned_runner.shard_metadata
weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
initialize_timeout = weights_size_kb / kilobytes_per_second + 120.0 # Add a constant 120.0 to ensure connection can be made as well
self.logger.info(f"initialize_timeout: {initialize_timeout}")
try:
assigned_runner.runner = await asyncio.wait_for(
RunnerSupervisor.create(
model_shard_meta=assigned_runner.shard_metadata,
hosts=assigned_runner.hosts,
logger=self.logger,
),
timeout=initialize_timeout,
)
except TimeoutError as e:
import traceback
tb = traceback.format_exc()
e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
async for event in self._fail_runner(e=e, runner_id=op.runner_id):
yield event
return
assigned_runner.runner = await RunnerSupervisor.create(
model_shard_meta=assigned_runner.shard_metadata,
hosts=assigned_runner.hosts,
logger=self.logger,
initialize_timeout=initialize_timeout
)
if assigned_runner.runner.healthy:
assigned_runner.status = LoadedRunnerStatus()
else:
# Log detailed reasons why the runner is not healthy
runner = assigned_runner.runner
health_issues: list[str] = []
if not runner.running:
health_issues.append("runner.running is False")
if runner.runner_process.returncode is not None:
health_issues.append(f"runner_process.returncode is {runner.runner_process.returncode}")
if runner.runner_process.stdin is None:
health_issues.append("runner_process.stdin is None")
elif runner.runner_process.stdin.is_closing():
health_issues.append("runner_process.stdin is closing")
if runner.runner_process.stdout is None:
health_issues.append("runner_process.stdout is None")
self.logger.warning(f"Runner status is not healthy: {', '.join(health_issues)}")
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
@@ -251,6 +254,9 @@ class Worker:
'''
assigned_runner = self.assigned_runners[op.runner_id]
if isinstance(assigned_runner.runner, RunnerSupervisor):
await assigned_runner.runner.astop() # astop the runner to ensure it clears out of memory.
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
@@ -280,37 +286,30 @@ class Worker:
task_status=TaskStatus.RUNNING,
))
try:
assert assigned_runner.runner is not None
assert assigned_runner.runner.healthy
async for chunk in assigned_runner.runner.stream_response(
task=op.task,
request_started_callback=partial(running_callback, queue)):
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(ChunkGenerated(
# todo: at some point we will no longer have a bijection between task_id and row_id.
# So we probably want to store a mapping between these two in our Worker object.
command_id=chunk.command_id,
chunk=chunk
))
assert assigned_runner.runner is not None
assert assigned_runner.runner.healthy
async for chunk in assigned_runner.runner.stream_response(
task=op.task,
request_started_callback=partial(running_callback, queue)):
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.COMPLETE,
await queue.put(ChunkGenerated(
# todo: at some point we will no longer have a bijection between task_id and row_id.
# So we probably want to store a mapping between these two in our Worker object.
command_id=chunk.command_id,
chunk=chunk
))
# After a successful inference:
assigned_runner.status = LoadedRunnerStatus()
await queue.put(assigned_runner.status_update_event())
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.COMPLETE,
))
# After a successful inference:
assigned_runner.status = LoadedRunnerStatus()
await queue.put(assigned_runner.status_update_event())
except Exception as e:
# An exception occurs in the runner supervisor
self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
await queue.put(event)
queue: Queue[Event] = asyncio.Queue()
task = asyncio.create_task(inner_execute(queue))
@@ -320,31 +319,31 @@ class Worker:
try:
# Yield items from the queue
# timeout = 30.
timeout = 3.
while True:
item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
if task.done() and (exception := task.exception()):
raise exception
try:
# Use a timeout to periodically check task status
item: Event = await asyncio.wait_for(queue.get(), timeout=0.01)
except asyncio.TimeoutError:
continue
yield item
timeout = 2.
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
):
if isinstance(item.runner_status, LoadedRunnerStatus):
assigned_runner.failures = []
break
except TimeoutError as e:
# Runner supervisor doesn't respond in time; so we put the runner & task into a failed state
self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
yield event
finally:
# Ensure the task is cleaned up
try:
await asyncio.wait_for(task, timeout=5)
except asyncio.TimeoutError:
self.logger.warning("Timed out waiting for task cleanup after inference execution.")
## Operation Planner
@@ -368,7 +367,7 @@ class Worker:
yield event
async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
async def fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
if runner_id in self.assigned_runners:
assigned_runner = self.assigned_runners[runner_id]
@@ -383,15 +382,15 @@ class Worker:
# Reset failure count back to 0 when successful
if len(assigned_runner.failures) >= 3:
# Too many retries. We will emit a DeleteInstance
# Too many retries. We will emit a DeleteInstance
yield InstanceDeleted(
instance_id=assigned_runner.instance_id
)
yield assigned_runner.status_update_event()
async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
async def fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
if runner_id in self.assigned_runners:
yield TaskStateUpdated(
task_id=task_id,
@@ -404,7 +403,7 @@ class Worker:
error_message=str(e)
)
async for event in self._fail_runner(e, runner_id):
async for event in self.fail_runner(e, runner_id):
yield event