Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-23 21:41:21 -05:00)

Compare commits: 7 commits on ciaran/ima... by alexcheema

922e8075d3 · 6ee745246d · cee48f6f34 · 2b67e84a03 · 7204fdeb4a · ec345a4315 · 9967dfa734
@@ -45,8 +45,8 @@ struct EXOApp: App {
         let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
         _thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
         enableLaunchAtLoginIfNeeded()
-        // Remove old LaunchDaemon components if they exist (from previous versions)
-        cleanupLegacyNetworkSetup()
+        // Install LaunchDaemon to disable Thunderbolt Bridge on startup (prevents network loops)
+        NetworkSetupHelper.promptAndInstallIfNeeded()
         // Check local network access periodically (warning disappears when user grants permission)
         localNetwork.startPeriodicChecking(interval: 10)
         controller.scheduleLaunch(after: 15)
@@ -136,36 +136,6 @@ struct EXOApp: App {
         }
     }
-
-    private func cleanupLegacyNetworkSetup() {
-        guard NetworkSetupHelper.hasInstalledComponents() else { return }
-        // Dispatch async to ensure app is ready before showing alert
-        DispatchQueue.main.async {
-            let alert = NSAlert()
-            alert.messageText = "EXO Network Configuration"
-            alert.informativeText =
-                "EXO needs to configure local network discovery on your device. This requires granting permission once."
-            alert.alertStyle = .informational
-            alert.addButton(withTitle: "Continue")
-            alert.addButton(withTitle: "Later")
-
-            let response = alert.runModal()
-            guard response == .alertFirstButtonReturn else {
-                Logger().info("User deferred legacy network setup cleanup")
-                return
-            }
-
-            do {
-                try NetworkSetupHelper.uninstall()
-                Logger().info("Cleaned up legacy network setup components")
-            } catch {
-                // Non-fatal: user may have cancelled admin prompt or cleanup may have
-                // partially succeeded. The app will continue normally.
-                Logger().warning(
-                    "Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
-                )
-            }
-        }
-    }
 }

 /// Helper for managing EXO's launch-at-login registration
@@ -11,6 +11,68 @@ enum NetworkSetupHelper {
     private static let legacyScriptDestination =
         "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
     private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
+    private static let requiredStartInterval: Int = 1786
+
+    private static let setupScript = """
+        #!/usr/bin/env bash
+
+        set -euo pipefail
+
+        PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
+
+        # Remove bridge0 interface
+        ifconfig bridge0 &>/dev/null && {
+            ifconfig bridge0 | grep -q 'member' && {
+                ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
+            }
+            ifconfig bridge0 destroy 2>/dev/null || true
+        }
+
+        # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
+        /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
+
+        networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
+            networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
+        } || true
+        """
+
+    /// Prompts user and installs the LaunchDaemon if not already installed.
+    /// Shows an alert explaining what will be installed before requesting admin privileges.
+    static func promptAndInstallIfNeeded() {
+        // Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
+        Task.detached(priority: .utility) {
+            // If already correctly installed, skip
+            if daemonAlreadyInstalled() {
+                return
+            }
+
+            // Show alert on main thread
+            let shouldInstall = await MainActor.run {
+                let alert = NSAlert()
+                alert.messageText = "EXO Network Configuration"
+                alert.informativeText =
+                    "EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
+                alert.alertStyle = .informational
+                alert.addButton(withTitle: "Install")
+                alert.addButton(withTitle: "Not Now")
+                return alert.runModal() == .alertFirstButtonReturn
+            }
+
+            guard shouldInstall else {
+                logger.info("User deferred network setup daemon installation")
+                return
+            }
+
+            do {
+                try installLaunchDaemon()
+                logger.info("Network setup launch daemon installed and started")
+            } catch {
+                logger.error(
+                    "Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
+                )
+            }
+        }
+    }

     /// Removes all EXO network setup components from the system.
     /// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -30,6 +92,100 @@ enum NetworkSetupHelper {
         return scriptExists || legacyScriptExists || plistExists
     }

+    private static func daemonAlreadyInstalled() -> Bool {
+        let manager = FileManager.default
+        let scriptExists = manager.fileExists(atPath: scriptDestination)
+        let plistExists = manager.fileExists(atPath: plistDestination)
+        guard scriptExists, plistExists else { return false }
+        guard
+            let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
+            installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
+                == setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
+        else {
+            return false
+        }
+        guard
+            let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
+            let plist = try? PropertyListSerialization.propertyList(
+                from: data, options: [], format: nil) as? [String: Any]
+        else {
+            return false
+        }
+        guard
+            let interval = plist["StartInterval"] as? Int,
+            interval == requiredStartInterval
+        else {
+            return false
+        }
+        if let programArgs = plist["ProgramArguments"] as? [String],
+            programArgs.contains(scriptDestination) == false
+        {
+            return false
+        }
+        return true
+    }
+
+    private static func installLaunchDaemon() throws {
+        let installerScript = makeInstallerScript()
+        try runShellAsAdmin(installerScript)
+    }
+
+    private static func makeInstallerScript() -> String {
+        """
+        set -euo pipefail
+
+        LABEL="\(daemonLabel)"
+        SCRIPT_DEST="\(scriptDestination)"
+        LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
+        PLIST_DEST="\(plistDestination)"
+        LOG_OUT="/var/log/\(daemonLabel).log"
+        LOG_ERR="/var/log/\(daemonLabel).err.log"
+
+        # First, completely remove any existing installation
+        launchctl bootout system/"$LABEL" 2>/dev/null || true
+        rm -f "$PLIST_DEST"
+        rm -f "$SCRIPT_DEST"
+        rm -f "$LEGACY_SCRIPT_DEST"
+        rm -f "$LOG_OUT" "$LOG_ERR"
+
+        # Install fresh
+        mkdir -p "$(dirname "$SCRIPT_DEST")"
+
+        cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
+        \(setupScript)
+        EOF_SCRIPT
+        chmod 755 "$SCRIPT_DEST"
+
+        cat > "$PLIST_DEST" <<'EOF_PLIST'
+        <?xml version="1.0" encoding="UTF-8"?>
+        <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+        <plist version="1.0">
+        <dict>
+            <key>Label</key>
+            <string>\(daemonLabel)</string>
+            <key>ProgramArguments</key>
+            <array>
+                <string>/bin/bash</string>
+                <string>\(scriptDestination)</string>
+            </array>
+            <key>StartInterval</key>
+            <integer>\(requiredStartInterval)</integer>
+            <key>RunAtLoad</key>
+            <true/>
+            <key>StandardOutPath</key>
+            <string>/var/log/\(daemonLabel).log</string>
+            <key>StandardErrorPath</key>
+            <string>/var/log/\(daemonLabel).err.log</string>
+        </dict>
+        </plist>
+        EOF_PLIST
+
+        launchctl bootstrap system "$PLIST_DEST"
+        launchctl enable system/"$LABEL"
+        launchctl kickstart -k system/"$LABEL"
+        """
+    }
+
+    private static func makeUninstallScript() -> String {
+        """
+        set -euo pipefail
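
The daemonAlreadyInstalled() check above validates three things: the script on disk matches the bundled setupScript, the plist parses, and StartInterval/ProgramArguments are what the installer writes. As a hedged illustration, the same validation can be expressed with Python's plistlib; the plist path and the 1786-second interval come from this diff, while the script path and expected contents are parameters here because scriptDestination's value does not appear in the hunk:

import plistlib
from pathlib import Path

PLIST_PATH = Path("/Library/LaunchDaemons/io.exo.networksetup.plist")
REQUIRED_START_INTERVAL = 1786  # matches requiredStartInterval in the diff

def daemon_already_installed(script_path: Path, expected_script: str) -> bool:
    # Both the helper script and the LaunchDaemon plist must exist.
    if not script_path.is_file() or not PLIST_PATH.is_file():
        return False
    # The installed script must match the bundled one, ignoring surrounding whitespace.
    if script_path.read_text(encoding="utf-8").strip() != expected_script.strip():
        return False
    try:
        plist = plistlib.loads(PLIST_PATH.read_bytes())
    except plistlib.InvalidFileException:
        return False
    # The daemon must fire on the expected schedule and point at our script.
    if plist.get("StartInterval") != REQUIRED_START_INTERVAL:
        return False
    return str(script_path) in plist.get("ProgramArguments", [])

Note this sketch is slightly stricter than the Swift version, which tolerates a missing ProgramArguments key.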
(File diff suppressed because it is too large.)
@@ -53,6 +53,7 @@ class Node:
         await router.register_topic(topics.COMMANDS)
         await router.register_topic(topics.ELECTION_MESSAGES)
         await router.register_topic(topics.CONNECTION_MESSAGES)
+        await router.register_topic(topics.STATE_CATCHUP)
         await router.register_topic(topics.DOWNLOAD_COMMANDS)

         logger.info(f"Starting node {node_id}")
@@ -82,6 +83,7 @@ class Node:
                 command_sender=router.sender(topics.COMMANDS),
                 download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                 election_receiver=router.receiver(topics.ELECTION_MESSAGES),
+                state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
             )
         else:
             api = None
@@ -94,6 +96,7 @@ class Node:
             global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
             local_event_sender=router.sender(topics.LOCAL_EVENTS),
             command_sender=router.sender(topics.COMMANDS),
+            state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
             download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
             event_index_counter=event_index_counter,
         )
@@ -107,6 +110,7 @@ class Node:
             global_event_sender=router.sender(topics.GLOBAL_EVENTS),
             local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
             command_receiver=router.receiver(topics.COMMANDS),
+            state_catchup_sender=router.sender(topics.STATE_CATCHUP),
         )

         er_send, er_recv = channel[ElectionResult]()
@@ -189,6 +193,7 @@ class Node:
                 global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
                 local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
                 command_receiver=self.router.receiver(topics.COMMANDS),
+                state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
             )
             self._tg.start_soon(self.master.run)
         elif (
@@ -235,6 +240,9 @@ class Node:
                 ),
                 local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                 command_sender=self.router.sender(topics.COMMANDS),
+                state_catchup_receiver=self.router.receiver(
+                    topics.STATE_CATCHUP
+                ),
                 download_command_sender=self.router.sender(
                     topics.DOWNLOAD_COMMANDS
                 ),
@@ -166,6 +166,7 @@ class API:
         download_command_sender: Sender[ForwarderDownloadCommand],
         # This lets us pause the API if an election is running
         election_receiver: Receiver[ElectionMessage],
+        state_catchup_receiver: Receiver[State],
     ) -> None:
         self.state = State()
         self._event_log: list[Event] = []
@@ -173,6 +174,7 @@ class API:
         self.download_command_sender = download_command_sender
         self.global_event_receiver = global_event_receiver
         self.election_receiver = election_receiver
+        self.state_catchup_receiver = state_catchup_receiver
         self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
         self.node_id: NodeId = node_id
         self.session_id: SessionId = session_id
@@ -1249,6 +1251,7 @@ class API:
             tg.start_soon(self._apply_state)
             tg.start_soon(self._pause_on_new_election)
            tg.start_soon(self._cleanup_expired_images)
+            tg.start_soon(self._state_catchup)
             print_startup_banner(self.port)
             await serve(
                 cast(ASGIFramework, self.app),
@@ -1259,6 +1262,37 @@ class API:
         self.command_sender.close()
         self.global_event_receiver.close()

+    async def _state_catchup(self):
+        with self.state_catchup_receiver as states:
+            async for state in states:
+                if (
+                    self.state.last_event_applied_idx == -1
+                    and state.last_event_applied_idx > self.state.last_event_applied_idx
+                ):
+                    # DEBUG: Log buffer state BEFORE clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: About to catch up. "
+                        f"Current buffer indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"next_idx_to_release: {self.event_buffer.next_idx_to_release}, "
+                        f"catching up to idx: {state.last_event_applied_idx}"
+                    )
+
+                    new_idx = state.last_event_applied_idx + 1
+                    self.event_buffer.next_idx_to_release = new_idx
+                    # Preserve events that arrived early but are still valid (idx >= new_idx)
+                    # Remove stale events (idx < new_idx) to prevent memory growth
+                    self.event_buffer.store = {
+                        k: v for k, v in self.event_buffer.store.items() if k >= new_idx
+                    }
+                    self.state = state
+
+                    # DEBUG: Log buffer state AFTER clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: Catchup complete. "
+                        f"Buffer preserved indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"new next_idx_to_release: {self.event_buffer.next_idx_to_release}"
+                    )
+
     async def _apply_state(self):
         with self.global_event_receiver as events:
             async for f_event in events:
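
_state_catchup reaches into event_buffer.store and event_buffer.next_idx_to_release directly, so its correctness hinges on how OrderedBuffer releases events. That class is not part of this diff; the following is a minimal sketch of the behaviour the catchup code assumes, with only the attribute and method names (store, next_idx_to_release, drain_indexed) taken from the hunks above, and ingest assumed:

from typing import Generic, TypeVar

T = TypeVar("T")

class OrderedBuffer(Generic[T]):
    """Holds out-of-order items and releases them strictly by index."""

    def __init__(self) -> None:
        self.store: dict[int, T] = {}      # early arrivals, keyed by index
        self.next_idx_to_release: int = 0  # next index allowed out

    def ingest(self, idx: int, item: T) -> None:  # name assumed, not from the diff
        if idx >= self.next_idx_to_release:  # stale duplicates are dropped
            self.store[idx] = item

    def drain_indexed(self) -> list[tuple[int, T]]:
        # Emit the longest contiguous run starting at next_idx_to_release.
        out: list[tuple[int, T]] = []
        while self.next_idx_to_release in self.store:
            out.append((self.next_idx_to_release, self.store.pop(self.next_idx_to_release)))
            self.next_idx_to_release += 1
        return out

Under this model, jumping next_idx_to_release forward and filtering store is exactly what lets a snapshot replace a replay without discarding events that arrived ahead of it.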
@@ -68,6 +68,8 @@ class Master:
         # Send events to the forwarder to be indexed (usually from command processing)
         # Ideally these would be MasterForwarderEvents but type system says no :(
         global_event_sender: Sender[ForwarderEvent],
+        # not a fan but - send the entire state to a node so it can catchup without the whole event log.
+        state_catchup_sender: Sender[State],
     ):
         self.state = State()
         self._tg: TaskGroup = anyio.create_task_group()
@@ -77,6 +79,7 @@ class Master:
         self.command_receiver = command_receiver
         self.local_event_receiver = local_event_receiver
         self.global_event_sender = global_event_sender
+        self.state_catchup_sender = state_catchup_sender
         send, recv = channel[Event]()
         self.event_sender: Sender[Event] = send
         self._loopback_event_receiver: Receiver[Event] = recv
@@ -84,7 +87,6 @@ class Master:
             local_event_receiver.clone_sender()
         )
         self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
-        # TODO: not have this
         self._event_log: list[Event] = []

     async def run(self):
@@ -291,11 +293,17 @@ class Master:
                             command.finished_command_id
                         ]
                 case RequestEventLog():
                     # We should just be able to send everything, since other buffers will ignore old messages
-                    for i in range(command.since_idx, len(self._event_log)):
-                        await self._send_event(
-                            IndexedEvent(idx=i, event=self._event_log[i])
-                        )
+                    if command.since_idx == 0:
+                        # This is an optimization, and should not be relied upon in theory.
+                        logger.info(
+                            f"Master sending catchup state for index {self.state.last_event_applied_idx}"
+                        )
+                        await self.state_catchup_sender.send(self.state)
+                    else:
+                        for i in range(command.since_idx, len(self._event_log)):
+                            await self._send_event(
+                                IndexedEvent(idx=i, event=self._event_log[i])
+                            )
             for event in generated_events:
                 await self.event_sender.send(event)
         except ValueError as e:
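
Taken together with the API receiver above and the Worker receiver below, the RequestEventLog change gives the master two catch-up paths: a single State snapshot for a brand-new requester (since_idx == 0), or an event-log replay from since_idx onward. A toy calculation of the message counts, under assumed numbers:

last_event_applied_idx = 999  # hypothetical: master has applied events 0..999
event_log_len = 1000

def messages_sent(since_idx: int) -> int:
    if since_idx == 0:
        return 1  # snapshot path: one State message replaces a full replay
    return event_log_len - since_idx  # replay path: events since_idx..999

assert messages_sent(0) == 1
assert messages_sent(640) == 360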
@@ -257,7 +257,13 @@ def _find_ip_prioritised(
     ip_to_type = {
         iface.ip_address: iface.interface_type for iface in other_network.interfaces
     }
-    priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
+    priority = {
+        "ethernet": 0,
+        "wifi": 1,
+        "unknown": 2,
+        "maybe_ethernet": 3,
+        "thunderbolt": 4,
+    }
     return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
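
The new maybe_ethernet tier slots enX adapters between genuinely unknown interfaces and Thunderbolt. A small worked example of the min-by-priority selection above, with hypothetical addresses:

ip_to_type = {
    "192.168.1.10": "ethernet",
    "10.0.0.5": "maybe_ethernet",  # enX adapter that may really be Thunderbolt
    "169.254.3.7": "thunderbolt",
}
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "maybe_ethernet": 3, "thunderbolt": 4}
ips = list(ip_to_type)
best = min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
assert best == "192.168.1.10"  # plain ethernet beats maybe_ethernet and thunderbolt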
@@ -27,6 +27,7 @@ from exo.shared.types.memory import Memory
 from exo.shared.types.profiling import (
     MemoryUsage,
 )
+from exo.shared.types.state import State
 from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
 from exo.shared.types.tasks import TaskStatus
 from exo.shared.types.worker.instances import (
@@ -47,6 +48,7 @@ async def test_master():
     ge_sender, global_event_receiver = channel[ForwarderEvent]()
     command_sender, co_receiver = channel[ForwarderCommand]()
     local_event_sender, le_receiver = channel[ForwarderEvent]()
+    st_s, _st_r = channel[State]()

     all_events: list[IndexedEvent] = []

@@ -67,6 +69,7 @@ async def test_master():
         global_event_sender=ge_sender,
         local_event_receiver=le_receiver,
         command_receiver=co_receiver,
+        state_catchup_sender=st_s,
     )
     logger.info("run the master")
     async with anyio.create_task_group() as tg:
@@ -7,6 +7,7 @@ from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
 from exo.shared.types.events import (
     ForwarderEvent,
 )
+from exo.shared.types.state import State
 from exo.utils.pydantic_ext import CamelCaseModel


@@ -45,6 +46,7 @@ ELECTION_MESSAGES = TypedTopic(
 CONNECTION_MESSAGES = TypedTopic(
     "connection_messages", PublishPolicy.Never, ConnectionMessage
 )
+STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)
 DOWNLOAD_COMMANDS = TypedTopic(
     "download_commands", PublishPolicy.Always, ForwarderDownloadCommand
 )
@@ -48,7 +48,7 @@ class SystemPerformanceProfile(CamelCaseModel):
     ecpu_usage: float = 0.0


-InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
+InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]


 class NetworkInterfaceInfo(CamelCaseModel):
@@ -400,7 +400,7 @@ class InfoGatherer:
             return
         old_nics = []
         while True:
-            nics = get_network_interfaces()
+            nics = await get_network_interfaces()
             if nics != old_nics:
                 old_nics = nics
                 await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
@@ -1,6 +1,6 @@
 import socket
 import sys
-from subprocess import CalledProcessError, run
+from subprocess import CalledProcessError

 import psutil
 from anyio import run_process
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
     """
     hostname = socket.gethostname()

-    # TODO: better non mac support
-    if sys.platform != "darwin":  # 'darwin' is the platform name for macOS
+    if sys.platform != "darwin":
         return hostname

     try:
@@ -28,21 +27,20 @@ async def get_friendly_name() -> str:
     return process.stdout.decode("utf-8", errors="replace").strip() or hostname


-def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
+async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
     """Parse networksetup -listallhardwareports to get interface types."""
     if sys.platform != "darwin":
         return {}

     try:
-        result = run(
-            ["networksetup", "-listallhardwareports"], capture_output=True, text=True
-        )
-    except Exception:
+        result = await run_process(["networksetup", "-listallhardwareports"])
+    except CalledProcessError:
         return {}

     types: dict[str, InterfaceType] = {}
     current_type: InterfaceType = "unknown"

-    for line in result.stdout.splitlines():
+    for line in result.stdout.decode().splitlines():
         if line.startswith("Hardware Port:"):
             port_name = line.split(":", 1)[1].strip()
             if "Wi-Fi" in port_name:
@@ -55,12 +53,15 @@ def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
             current_type = "unknown"
         elif line.startswith("Device:"):
             device = line.split(":", 1)[1].strip()
+            # enX is ethernet adapters or thunderbolt - these must be deprioritised
+            if device.startswith("en") and device not in ["en0", "en1"]:
+                current_type = "maybe_ethernet"
             types[device] = current_type

     return types


-def get_network_interfaces() -> list[NetworkInterfaceInfo]:
+async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
     """
     Retrieves detailed network interface information on macOS.
     Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -68,7 +69,7 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
     Returns a list of NetworkInterfaceInfo objects.
     """
     interfaces_info: list[NetworkInterfaceInfo] = []
-    interface_types = _get_interface_types_from_networksetup()
+    interface_types = await _get_interface_types_from_networksetup()

     for iface, services in psutil.net_if_addrs().items():
         for service in services:
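
The conversion above swaps the blocking subprocess.run for anyio's run_process, which (with its default check=True) raises subprocess.CalledProcessError on a non-zero exit and returns output as bytes, hence the narrower except clause and the added .decode(). A self-contained sketch of the pattern, assuming a macOS host where the networksetup binary exists:

import anyio
from subprocess import CalledProcessError
from anyio import run_process

async def list_hardware_ports() -> str:
    try:
        # check=True is the default: a non-zero exit raises CalledProcessError.
        result = await run_process(["networksetup", "-listallhardwareports"])
    except CalledProcessError:
        return ""
    return result.stdout.decode()  # stdout is bytes, unlike run(..., text=True)

if __name__ == "__main__":
    print(anyio.run(list_hardware_ports))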
@@ -60,9 +60,8 @@ class Worker:
         connection_message_receiver: Receiver[ConnectionMessage],
         global_event_receiver: Receiver[ForwarderEvent],
         local_event_sender: Sender[ForwarderEvent],
-        # This is for requesting updates. It doesn't need to be a general command sender right now,
-        # but I think it's the correct way to be thinking about commands
         command_sender: Sender[ForwarderCommand],
+        state_catchup_receiver: Receiver[State],
         download_command_sender: Sender[ForwarderDownloadCommand],
         event_index_counter: Iterator[int],
     ):
@@ -71,6 +70,8 @@ class Worker:

         self.global_event_receiver = global_event_receiver
         self.local_event_sender = local_event_sender
+        self.state_catchup_receiver = state_catchup_receiver
+        self.local_event_index = 0
         self.event_index_counter = event_index_counter
         self.command_sender = command_sender
         self.download_command_sender = download_command_sender
@@ -110,6 +111,7 @@ class Worker:
             tg.start_soon(self._event_applier)
             tg.start_soon(self._forward_events)
             tg.start_soon(self._poll_connection_updates)
+            tg.start_soon(self._check_catchup_state)

         # Actual shutdown code - waits for all tasks to complete before executing.
         self.local_event_sender.close()
@@ -121,13 +123,47 @@ class Worker:
     async def _forward_info(self, recv: Receiver[GatheredInfo]):
         with recv as info_stream:
             async for info in info_stream:
-                await self.event_sender.send(
-                    NodeGatheredInfo(
-                        node_id=self.node_id,
-                        when=str(datetime.now(tz=timezone.utc)),
-                        info=info,
-                    )
-                )
+                event = NodeGatheredInfo(
+                    node_id=self.node_id,
+                    when=str(datetime.now(tz=timezone.utc)),
+                    info=info,
+                )
+                logger.warning(
+                    f"NODE_GATHERED_INFO: Sending event for node {self.node_id}, "
+                    f"event_id={event.event_id}"
+                )
+                await self.event_sender.send(event)

+    async def _check_catchup_state(self):
+        with self.state_catchup_receiver as states:
+            async for state in states:
+                if (
+                    self.state.last_event_applied_idx == -1
+                    and state.last_event_applied_idx > self.state.last_event_applied_idx
+                ):
+                    # DEBUG: Log buffer state BEFORE clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: About to catch up. "
+                        f"Current buffer indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"next_idx_to_release: {self.event_buffer.next_idx_to_release}, "
+                        f"catching up to idx: {state.last_event_applied_idx}"
+                    )
+
+                    new_idx = state.last_event_applied_idx + 1
+                    self.event_buffer.next_idx_to_release = new_idx
+                    # Preserve events that arrived early but are still valid (idx >= new_idx)
+                    # Remove stale events (idx < new_idx) to prevent memory growth
+                    self.event_buffer.store = {
+                        k: v for k, v in self.event_buffer.store.items() if k >= new_idx
+                    }
+                    self.state = state
+
+                    # DEBUG: Log buffer state AFTER clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: Catchup complete. "
+                        f"Buffer preserved indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"new next_idx_to_release: {self.event_buffer.next_idx_to_release}"
+                    )

     async def _event_applier(self):
         with self.global_event_receiver as events:
@@ -139,8 +175,20 @@ class Worker:
                     if event_id in self.out_for_delivery:
                         del self.out_for_delivery[event_id]

+                # DEBUG: Log what was ingested
+                logger.warning(
+                    f"EVENT_APPLIER: Ingested event idx={f_event.origin_idx}, "
+                    f"buffer keys now: {sorted(self.event_buffer.store.keys())}"
+                )
+
                 # 2. for each event, apply it to the state
                 indexed_events = self.event_buffer.drain_indexed()
+
+                # DEBUG: Log drain results
+                logger.warning(
+                    f"EVENT_APPLIER: Drained {len(indexed_events)} events, "
+                    f"next_idx_to_release now: {self.event_buffer.next_idx_to_release}"
+                )
                 if indexed_events:
                     self._nack_attempts = 0
@@ -157,6 +205,12 @@ class Worker:
                     self._nack_cancel_scope.cancel()

                 for idx, event in indexed_events:
+                    # DEBUG: Log NodeGatheredInfo events
+                    if isinstance(event, NodeGatheredInfo):
+                        logger.warning(
+                            f"NODE_GATHERED_INFO: Applying event idx={idx} for node {event.node_id}, "
+                            f"event_id={event.event_id}"
+                        )
                     self.state = apply(self.state, IndexedEvent(idx=idx, event=event))

                 # Buffer input image chunks for image editing
@@ -318,10 +372,7 @@ class Worker:
         # We request all events after (and including) the missing index.
         # This function is started whenever we receive an event that is out of sequence.
         # It is cancelled as soon as we receive an event that is in sequence.
-
-        if since_idx < 0:
-            logger.warning(f"Negative value encountered for nack request {since_idx=}")
-            since_idx = 0
+        assert since_idx >= 0

         with CancelScope() as scope:
             self._nack_cancel_scope = scope
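
Continuing the OrderedBuffer sketch from after the API diff, here is what the catchup trim in _check_catchup_state does to a fresh worker that saw events 5 and 7 arrive early and then received a snapshot with last_event_applied_idx == 4 (numbers hypothetical):

buf = OrderedBuffer[str]()   # the sketch class from after the API diff
buf.ingest(5, "e5")          # arrived early
buf.ingest(7, "e7")          # arrived early, with a gap at 6

new_idx = 4 + 1              # snapshot's last_event_applied_idx + 1
buf.next_idx_to_release = new_idx
buf.store = {k: v for k, v in buf.store.items() if k >= new_idx}  # nothing stale here

assert buf.drain_indexed() == [(5, "e5")]  # 5 releases; 7 still waits on 6
assert buf.next_idx_to_release == 6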
@@ -240,10 +240,6 @@ def main(
         prompt=prompt,
     )

-    # GPT-OSS specific parsing to match other model formats.
-    if isinstance(model, GptOssModel):
-        mlx_generator = parse_gpt_oss(mlx_generator)
-
     # For other thinking models (GLM, etc.), check if we need to
     # prepend the thinking tag that was consumed by the chat template
     if detect_thinking_prompt_suffix(prompt, tokenizer):
@@ -257,10 +253,16 @@ def main(
         patch_kimi_tokenizer(tokenizer)

     # GLM models need patched parser (upstream has bug with None regex match)
-    if "glm" in shard_metadata.model_card.model_id.lower():
+    elif "glm" in shard_metadata.model_card.model_id.lower():
         patch_glm_tokenizer(tokenizer)

-    if tokenizer.has_tool_calling:
+    # GPT-OSS specific parsing to match other model formats.
+    elif isinstance(model, GptOssModel):
+        mlx_generator = parse_gpt_oss(mlx_generator)
+
+    if tokenizer.has_tool_calling and not isinstance(
+        model, GptOssModel
+    ):
         assert tokenizer.tool_call_start
         assert tokenizer.tool_call_end
         assert tokenizer.tool_parser  # pyright: ignore[reportAny]
@@ -489,9 +491,10 @@ def get_gpt_oss_encoding():


 def filter_kimi_tokens(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
 ) -> Generator[GenerationResponse]:
     for resp in responses:
+        assert isinstance(resp, GenerationResponse)
         if (
             resp.text == "<|tool_calls_section_begin|>"
             or resp.text == "<|tool_calls_section_end|>"
@@ -501,17 +504,44 @@ def filter_kimi_tokens(


 def parse_gpt_oss(
-    responses: Generator[GenerationResponse],
-) -> Generator[GenerationResponse]:
+    responses: Generator[GenerationResponse | ToolCallResponse],
+) -> Generator[GenerationResponse | ToolCallResponse]:
     encoding = get_gpt_oss_encoding()
     stream = StreamableParser(encoding, role=Role.ASSISTANT)
     thinking = False
+    current_tool_name: str | None = None
+    tool_arg_parts: list[str] = []

     for response in responses:
+        assert isinstance(response, GenerationResponse)
         stream.process(response.token)

         delta = stream.last_content_delta
         ch = stream.current_channel
+        recipient = stream.current_recipient
+
+        if recipient != current_tool_name:
+            if current_tool_name is not None:
+                prefix = "functions."
+                if current_tool_name.startswith(prefix):
+                    current_tool_name = current_tool_name[len(prefix) :]
+                yield ToolCallResponse(
+                    tool_calls=[
+                        ToolCallItem(
+                            name=current_tool_name,
+                            arguments="".join(tool_arg_parts).strip(),
+                        )
+                    ]
+                )
+                tool_arg_parts = []
+                break
+            current_tool_name = recipient
+
+        # If inside a tool call, accumulate arguments
+        if current_tool_name is not None:
+            if delta:
+                tool_arg_parts.append(delta)
+            continue

         if ch == "analysis" and not thinking:
             thinking = True
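
The recipient-tracking added to parse_gpt_oss above is easiest to see in isolation. Below is a toy walkthrough with the harmony parser stubbed out as (delta, recipient) pairs; ToolCallItem is a stand-in for the real type, removeprefix replaces the diff's slice, and unlike the real generator this sketch keeps scanning after flushing instead of break-ing:

from dataclasses import dataclass

@dataclass
class ToolCallItem:  # stand-in for the real ToolCallItem in the diff
    name: str
    arguments: str

def flush_tool_calls(stream):
    """stream yields (delta, recipient) pairs, mimicking
    StreamableParser.last_content_delta / .current_recipient."""
    current_tool_name: str | None = None
    tool_arg_parts: list[str] = []
    for delta, recipient in stream:
        if recipient != current_tool_name:
            if current_tool_name is not None:
                # Recipient changed back: flush one completed tool call,
                # stripping the "functions." prefix like the diff does.
                yield ToolCallItem(
                    name=current_tool_name.removeprefix("functions."),
                    arguments="".join(tool_arg_parts).strip(),
                )
                tool_arg_parts = []
            current_tool_name = recipient
        if current_tool_name is not None and delta:
            tool_arg_parts.append(delta)

calls = list(flush_tool_calls([
    ("", "functions.get_weather"),
    ('{"city": ', "functions.get_weather"),
    ('"Paris"}', "functions.get_weather"),
    ("", None),  # recipient cleared -> the buffered call is flushed
]))
assert calls == [ToolCallItem(name="get_weather", arguments='{"city": "Paris"}')]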
@@ -528,13 +558,12 @@ def parse_gpt_oss(
         if thinking:
             yield response.model_copy(update={"text": "</think>"})
         yield response
-        break


 def parse_thinking_models(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
     tokenizer: TokenizerWrapper,
-) -> Generator[GenerationResponse]:
+) -> Generator[GenerationResponse | ToolCallResponse]:
     """
     For models that inject thinking tags in the prompt (like GLM-4.7),
     prepend the thinking tag to the output stream so the frontend
@@ -542,6 +571,9 @@ def parse_thinking_models(
     """
     first = True
     for response in responses:
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
         if first:
             first = False
             yield response.model_copy(
@@ -622,7 +654,7 @@ def _process_image_response(


 def parse_tool_calls(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
     tool_call_start: str,
     tool_call_end: str,
     tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
@@ -630,6 +662,7 @@ def parse_tool_calls(
     in_tool_call = False
     tool_call_text_parts: list[str] = []
     for response in responses:
+        assert isinstance(response, GenerationResponse)
         # assumption: the tool call start is one token
        if response.text == tool_call_start:
             in_tool_call = True

@@ -154,7 +154,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
         tasks={},
     )

-    assert result is None
+    assert not isinstance(result, plan_mod.DownloadModel)


 def test_plan_does_not_load_model_until_all_shards_downloaded_globally():