Compare commits

..

10 Commits

Author SHA1 Message Date
Evan
00f25ce239 maybe maybe 2026-01-26 21:29:09 +00:00
Alex Cheema
44453c4c8b Remove change-detection checks from info gatherer monitors (#1283)
## Summary
- When a node times out, its info gets cleared from state. The monitor
functions only sent data when something changed, leaving no mechanism to
re-populate this info after a timeout.
- Removes change-detection checks from `_monitor_misc`,
`_monitor_system_profiler_thunderbolt_data`, `_watch_system_info`, and
`_monitor_thunderbolt_bridge_status` so data is sent periodically
regardless of whether it changed.

## Test plan
- [ ] Verify type checker passes: `uv run basedpyright`
- [ ] Verify linter passes: `uv run ruff check`
- [ ] Verify tests pass: `uv run pytest`
- [ ] Manually test that node info is re-populated after a timeout by
observing cluster behavior

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 12:23:22 +00:00
Jake Hillion
1290e8ed9f dashboard: fix prettier-svelte rebuilding on every file change
The prettier-svelte package was rebuilding whenever any file in the
repository changed because dashboardStubSrc referenced inputs.self
directly. Since inputs.self's store path hash is computed from the
entire repository contents, any file modification invalidated the
derivation.

Added dashboardLockfileSrc using lib.cleanSourceWith to filter
inputs.self to only include package.json and package-lock.json from
the dashboard directory. Updated dashboardStubSrc to reference this
filtered source instead of inputs.self directly.

This ensures prettier-svelte only rebuilds when the lockfiles actually
change, significantly improving build caching for unrelated changes.

Test plan:
- Built prettier-svelte with nix build .#prettier-svelte
- Modified src/exo/main.py and rebuilt - same store path (no rebuild)
- Modified dashboard/package.json and rebuilt - different store path (rebuild triggered)
- Ran nix flake check successfully
2026-01-26 12:02:05 +00:00
Evan Quiney
d93db3d6bf re enable the evil network script (#1277)
seems like we still need the interfaces to be routable for mdns. at
least we're not dependent on this behaviour anymore.
2026-01-24 13:36:06 +00:00
Alex Cheema
ff4a2022f7 Revert state compaction (#1259) (#1275)
## Summary

Reverts the state compaction feature (#1259) to investigate issues with
nodes staying as "unknown" after joining a cluster.

## Test plan

- [ ] Verify nodes properly show up after joining cluster
- [ ] Verify state catchup works correctly without compaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-01-23 16:29:48 -08:00
rltakashige
cee48f6f34 Parse GPT OSS tool calling (#1271)
## Motivation

<img width="3162" height="858" alt="image"
src="https://github.com/user-attachments/assets/e552f373-620a-4522-894b-6f93fd7f1e50"
/>

## Changes

OpenAI Harmony StreamableParser does parsing for us.

## Why It Works

<img width="3230" height="588" alt="image"
src="https://github.com/user-attachments/assets/81f8a43e-c04b-4bd0-9fd0-65e9b5f6ea1d"
/>
2026-01-23 20:43:53 +00:00
Evan Quiney
2b67e84a03 state compaction (#1259)
## motivation

a node joining a long-running cluster would bring down networking. this
attempts to mitigate that issue by compacting the state for catching up
new devices

## changes

introduces a new topic ("state_catchup") over which a full state can be
sent. currently the master sends the worker + api this new state, and
they update only if they have no other events applied - otherwise usual
NACK systems function

## testing

manually tested on two and eight nodes - its an improvement, not a fix

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-23 20:32:49 +00:00
Alex Cheema
7204fdeb4a Restore Thunderbolt Bridge LaunchDaemon (#1270)
## Motivation

The LaunchDaemon approach for disabling Thunderbolt Bridge was removed
in commit 43f12f5d and replaced with dynamic cycle detection. However,
the LaunchDaemon runs automatically on reboot, ensuring the bridge is
always disabled before it can cause packet storms.

## Changes

- Restore `NetworkSetupHelper.promptAndInstallIfNeeded()` to install a
LaunchDaemon that disables Thunderbolt Bridge on startup
- Show user prompt explaining what will be installed before requesting
admin password
- Remove old cleanup-only logic from `EXOApp.swift`
- Installer removes any existing installation before installing fresh
(handles upgrades)

## Why It Works

The LaunchDaemon runs at boot with `RunAtLoad=true` and periodically
(every ~30 min), destroying bridge0 and disabling Thunderbolt Bridge
before it can cause packet storms. The daemon is only installed
once—`daemonAlreadyInstalled()` checks script content and plist config
match before prompting.

## Test Plan

### Manual Testing
- Run app first time → should see prompt → click Install → enter admin
password → daemon installed
- Run app again → no prompt (already installed)
- Reboot → bridge0 should be destroyed/disabled automatically
- Check daemon: `launchctl list | grep io.exo.networksetup`
- Check files: `/Library/LaunchDaemons/io.exo.networksetup.plist`,
`/Library/Application Support/EXO/disable_bridge.sh`

### Automated Testing
N/A - requires admin privileges and system-level changes

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 20:25:37 +00:00
Evan Quiney
ec345a4315 fix: deprioritise uncertain ethernet devices (#1267)
we were placing coordinators on uncertain devices (enX+) that are listed
as "USB LAN" - these could be thunderbolt ports breaking RDMA instances
2026-01-23 20:13:28 +00:00
ciaranbor
9967dfa734 Prevent conversation collision (#1266)
## Motivation

When a user switched conversations while a response was still streaming,
the streaming content would be written to the currently selected
conversation instead of the original one. For streamed image generation,
each partial image would be written to the open conversation

## Changes

Added helper methods to track and update the correct conversation during
streaming:
- updateConversationMessage() - Update a message in a specific
conversation by ID
- syncActiveMessagesIfNeeded() - Sync this.messages from target
conversation only if it's active
- conversationExists() - Check if a conversation still exists (handles
mid-stream deletion)
  - persistConversation() - Persist a specific conversation to storage
- addMessageToConversation() - Add a message directly to a specific
conversation


## Why It Works

Capturing the conversation ID at the start of the request ensures we
know which conversation to update

## Test Plan

### Manual Testing

Tested switching conversation during generation across each model type
2026-01-23 19:59:08 +00:00
14 changed files with 1023 additions and 442 deletions

23
Cargo.lock generated
View File

@@ -514,6 +514,20 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "cluster_membership"
version = "0.0.1"
dependencies = [
"anyhow",
"async-trait",
"futures-lite",
"futures-timer",
"libp2p",
"log",
"tokio",
"tracing-subscriber",
]
[[package]]
name = "colorchoice"
version = "1.0.4"
@@ -1030,6 +1044,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1138,7 +1158,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]

View File

@@ -3,7 +3,7 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/util",
"rust/util", "rust/cluster_membership",
]
[workspace.package]
@@ -62,6 +62,7 @@ frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-util = "0.3"
futures-timer = "3.0"

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 30
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -34,6 +31,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -3,12 +3,28 @@
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to ONLY include package.json and package-lock.json
# This ensures prettier-svelte only rebuilds when lockfiles change
dashboardLockfileSrc = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
isDashboardDir = baseName == "dashboard" && type == "directory";
isPackageFile =
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
&& (baseName == "package.json" || baseName == "package-lock.json");
in
isDashboardDir || isPackageFile;
};
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
[package]
name = "cluster_membership"
version.workspace = true
edition.workspace = true
publish = false
[dependencies]
# util
anyhow.workspace = true
log.workspace = true
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
# async
tokio = { workspace = true, features = ["full"] }
futures-timer = { workspace = true }
futures-lite = "2.6.1"
# networking
libp2p = { workspace = true, features = ["full"] }
async-trait = "0.1.89"
[lints]
workspace = true

View File

@@ -0,0 +1,30 @@
use cluster_membership::Peer;
use libp2p::identity::ed25519::SecretKey;
use tokio::io::{self, AsyncBufReadExt};
use tracing_subscriber::{EnvFilter, filter::LevelFilter};
#[tokio::main]
async fn main() {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (mut peer, send, mut recv) =
Peer::new(SecretKey::generate(), "hello".to_string()).expect("peer should always build");
let ch = peer.subscribe("chatroom".to_string());
let jh = tokio::spawn(async move { peer.run().await });
let mut stdin = io::BufReader::new(io::stdin()).lines();
loop {
tokio::select! {
Ok(Some(line)) = stdin.next_line() => {send.send((ch.clone(), line.into_bytes())).await.expect("example");}
Some(r) = recv.recv() => match r {
Ok((_, id, line)) => println!("{:?}:{:?}", id, String::from_utf8_lossy(&line)),
Err(e) => eprintln!("{e:?}"),
},
else => break
}
}
jh.await.expect("task failure");
}

View File

@@ -0,0 +1,220 @@
use libp2p::{
Multiaddr, PeerId, Swarm, SwarmBuilder,
futures::StreamExt,
gossipsub::{self, PublishError, Sha256Topic, TopicHash},
identify,
identity::{Keypair, ed25519},
mdns,
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
};
use std::{
collections::HashMap,
time::{Duration, Instant},
};
use tokio::{select, sync::mpsc};
const DEFAULT_BUFFER_SIZE: usize = 10;
const MDNS_IGNORE_DURATION_SECS: u64 = 30;
impl Peer {
pub fn new(
identity: ed25519::SecretKey,
namespace: String,
) -> anyhow::Result<(
Self,
mpsc::Sender<(TopicHash, Vec<u8>)>,
mpsc::Receiver<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
)> {
let mut id_bytes = identity.as_ref().to_vec();
let mut swarm =
SwarmBuilder::with_existing_identity(Keypair::ed25519_from_bytes(&mut id_bytes)?)
.with_tokio()
.with_quic()
// TODO(evan): .with_bandwidth_metrics();
.with_behaviour(|kp| Behaviour::new(kp, namespace.clone()))?
.build();
swarm.listen_on("/ip6/::/udp/0/quic-v1".parse()?)?;
swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?;
let (to_swarm, from_client) = mpsc::channel(DEFAULT_BUFFER_SIZE);
let (to_client, from_swarm) = mpsc::channel(DEFAULT_BUFFER_SIZE);
Ok((
Self {
swarm,
namespace,
denied: HashMap::new(),
from_client,
to_client,
},
to_swarm,
from_swarm,
))
}
pub fn subscribe(&mut self, topic: String) -> TopicHash {
let topic = Sha256Topic::new(topic);
self.swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("topic filtered");
topic.hash()
}
pub async fn run(&mut self) {
loop {
select! {
ev = self.swarm.select_next_some() => {
let Ok(()) = self.handle_swarm_event(ev).await else {
return
};
},
Some(msg) = self.from_client.recv() => {
if let Err(e) = self.swarm.behaviour_mut().gossipsub.publish(msg.0, msg.1) {
let Ok(()) = self.to_client.send(Err(e)).await else {
return
};
}
},
}
}
}
async fn handle_swarm_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
let SwarmEvent::Behaviour(event) = event else {
if let SwarmEvent::NewListenAddr {
listener_id: _,
address,
} = event
{
log::info!("new listen address {address}")
}
return Ok(());
};
match event {
BehaviourEvent::Mdns(mdns_event) => match mdns_event {
mdns::Event::Discovered(vec) => {
// Dial everyone
let mut addrs = HashMap::<PeerId, Vec<Multiaddr>>::new();
vec.into_iter()
.filter(|(peer_id, _)| {
self.denied.get(peer_id).is_none_or(|t| {
t.elapsed() > Duration::from_secs(MDNS_IGNORE_DURATION_SECS)
})
})
.for_each(|(peer_id, addr)| addrs.entry(peer_id).or_default().push(addr));
addrs.into_iter().for_each(|(peer_id, addrs)| {
let _ = self
.swarm
.dial(DialOpts::peer_id(peer_id).addresses(addrs).build());
});
}
mdns::Event::Expired(vec) => {
vec.iter().for_each(|(peer_id, _)| {
log::debug!("{peer_id} no longer reachable on mDNS");
self.swarm
.behaviour_mut()
.gossipsub
.remove_explicit_peer(peer_id);
});
}
},
BehaviourEvent::Identify(identify::Event::Received {
connection_id: _,
peer_id,
info,
}) => {
if info
.protocols
.iter()
.any(|p| p.as_ref().contains(&self.namespace))
{
self.passed_namespace(peer_id);
} else {
self.failed_namespace(peer_id);
}
}
BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: _,
message_id: _,
message:
gossipsub::Message {
topic,
data,
source: Some(source_peer),
..
},
}) => {
let Ok(()) = self.to_client.send(Ok((topic, source_peer, data))).await else {
return Err(());
};
}
_ => {}
}
Ok(())
}
fn passed_namespace(&mut self, peer: PeerId) {
log::info!("new peer {peer:?}");
self.denied.remove(&peer);
self.swarm
.behaviour_mut()
.gossipsub
.remove_blacklisted_peer(&peer);
self.swarm
.behaviour_mut()
.gossipsub
.add_explicit_peer(&peer);
}
fn failed_namespace(&mut self, peer: PeerId) {
log::debug!("{peer} failed handshake");
self.denied.insert(peer, Instant::now());
self.swarm.behaviour_mut().gossipsub.blacklist_peer(&peer);
// we don't care if disconnect fails
let _ = self.swarm.disconnect_peer_id(peer);
}
}
pub struct Peer {
pub swarm: Swarm<Behaviour>,
denied: HashMap<PeerId, Instant>,
namespace: String,
to_client: mpsc::Sender<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
from_client: mpsc::Receiver<(TopicHash, Vec<u8>)>,
}
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
pub gossipsub: gossipsub::Behaviour,
identify: identify::Behaviour,
}
impl Behaviour {
fn new(kp: &Keypair, namespace: String) -> Self {
let mdns = mdns::tokio::Behaviour::new(Default::default(), kp.public().to_peer_id())
.expect("implementation is infallible");
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(kp.clone()),
gossipsub::ConfigBuilder::default()
.max_transmit_size(1024 * 1024)
.protocol_id_prefix(format!("/exo/gossip/{namespace}/v1"))
.build()
.expect("fixed gossipsub config should always build"),
)
.expect("fixed gossipsub init should always build");
let identify = identify::Behaviour::new(
identify::Config::new_with_signed_peer_record(format!("/exo/identity/v1"), kp)
.with_push_listen_addr_updates(true),
);
Behaviour {
mdns,
gossipsub,
identify,
}
}
}

View File

@@ -257,7 +257,13 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

View File

@@ -48,7 +48,7 @@ class SystemPerformanceProfile(CamelCaseModel):
ecpu_usage: float = 0.0
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]
class NetworkInterfaceInfo(CamelCaseModel):

View File

@@ -349,13 +349,8 @@ class InfoGatherer:
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await self.info_sender.send(await MiscData.gather())
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
@@ -365,15 +360,12 @@ class InfoGatherer:
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
@@ -398,22 +390,17 @@ class InfoGatherer:
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
nics = await get_network_interfaces()
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
if curr is not None:
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)

View File

@@ -1,6 +1,6 @@
import socket
import sys
from subprocess import CalledProcessError, run
from subprocess import CalledProcessError
import psutil
from anyio import run_process
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
"""
hostname = socket.gethostname()
# TODO: better non mac support
if sys.platform != "darwin": # 'darwin' is the platform name for macOS
if sys.platform != "darwin":
return hostname
try:
@@ -28,21 +27,20 @@ async def get_friendly_name() -> str:
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
"""Parse networksetup -listallhardwareports to get interface types."""
if sys.platform != "darwin":
return {}
try:
result = run(
["networksetup", "-listallhardwareports"], capture_output=True, text=True
)
except Exception:
result = await run_process(["networksetup", "-listallhardwareports"])
except CalledProcessError:
return {}
types: dict[str, InterfaceType] = {}
current_type: InterfaceType = "unknown"
for line in result.stdout.splitlines():
for line in result.stdout.decode().splitlines():
if line.startswith("Hardware Port:"):
port_name = line.split(":", 1)[1].strip()
if "Wi-Fi" in port_name:
@@ -55,12 +53,15 @@ def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
current_type = "unknown"
elif line.startswith("Device:"):
device = line.split(":", 1)[1].strip()
# enX is ethernet adapters or thunderbolt - these must be deprioritised
if device.startswith("en") and device not in ["en0", "en1"]:
current_type = "maybe_ethernet"
types[device] = current_type
return types
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -68,7 +69,7 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = []
interface_types = _get_interface_types_from_networksetup()
interface_types = await _get_interface_types_from_networksetup()
for iface, services in psutil.net_if_addrs().items():
for service in services:

View File

@@ -240,10 +240,6 @@ def main(
prompt=prompt,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
@@ -257,10 +253,16 @@ def main(
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
@@ -489,9 +491,10 @@ def get_gpt_oss_encoding():
def filter_kimi_tokens(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
@@ -501,17 +504,44 @@ def filter_kimi_tokens(
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
]
)
tool_arg_parts = []
break
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
@@ -528,13 +558,12 @@ def parse_gpt_oss(
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
def parse_thinking_models(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -542,6 +571,9 @@ def parse_thinking_models(
"""
first = True
for response in responses:
if isinstance(response, ToolCallResponse):
yield response
continue
if first:
first = False
yield response.model_copy(
@@ -622,7 +654,7 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
@@ -630,6 +662,7 @@ def parse_tool_calls(
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True

View File

@@ -154,7 +154,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
tasks={},
)
assert result is None
assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():