Compare commits

..

5 Commits

Author SHA1 Message Date
Alex Cheema
daf2f9f48e fix: keep TRUST_REMOTE_CODE=True for built-in models
The constant is the default for built-in models with known model cards,
which are trusted. Custom models added via API already default to
trust_remote_code=False in ModelCard.fetch_from_hf(). The CLI flag
overrides custom models only.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 13:21:56 -08:00
Alex Cheema
5617aed345 feat: add --trust-remote-code CLI flag for custom model tokenizers
Some custom models (e.g. Kimi) require trust_remote_code=True to load
their tokenizers. This adds an opt-in CLI flag that sets an env var
read by runner subprocesses, following the same pattern as --fast-synch.
The flag is intentionally CLI-only (not API-accessible) to prevent
remote code execution attacks via the API.

Also changes the default TRUST_REMOTE_CODE constant from True to False,
making remote code execution fully opt-in.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 13:17:52 -08:00
rltakashige
365dd68d9a Final fixes for release (#1603)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-02-23 21:10:15 +00:00
Alex Cheema
d3d129581e test: verify instance deletion cancels ongoing tasks (#1508)
## Summary
- The cancellation logic for issue #1215 already exists in
`get_transition_events()` (`src/exo/master/placement.py:208-227`) — when
an instance is deleted, `TaskStatusUpdated(Cancelled)` events are
emitted for all Pending/Running tasks on that instance
- Combined with PR #1276's token-boundary cancellation in runners, the
full pipeline works end-to-end
- However, the existing test
`test_get_transition_events_delete_instance` passed `{}` for tasks, so
this path was never exercised
- This PR adds 4 tests covering the cancellation behavior:
  - Running tasks are cancelled on instance deletion
  - Pending tasks are cancelled on instance deletion
  - Completed/Failed/TimedOut/Cancelled tasks are left alone
  - Only tasks matching the deleted instance are cancelled

Closes #1215

## Test plan
- [x] `uv run pytest src/exo/master/tests/test_placement.py -v` — all 15
tests pass
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:12:23 +00:00
Alex Cheema
c90a0cec78 fix: suppress closure errors in runnersupervisor and force spawn start method (#1547)
some errors could be thrown during shutdown - we can dismiss these safely

co-authored by me :)
2026-02-23 18:30:41 +00:00
14 changed files with 638 additions and 266 deletions

View File

@@ -1,6 +1,7 @@
import asyncio
import socket
from dataclasses import dataclass, field
from random import random
import anyio
from anyio import current_time
@@ -21,9 +22,13 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
# TODO(evan): just for acks, should delete this ASAP
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -34,28 +39,40 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.task_group import TaskGroup
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
event_sender: Sender[Event]
local_event_sender: Sender[LocalForwarderEvent]
# ack stuff
_global_event_receiver: Receiver[GlobalForwarderEvent]
_out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
@@ -111,7 +128,10 @@ class DownloadCoordinator:
try:
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._clear_ofd)
if not self.offline:
tg.start_soon(self._check_internet_connection)
finally:
@@ -149,6 +169,20 @@ class DownloadCoordinator:
def shutdown(self) -> None:
self._tg.cancel_tasks()
# directly copied from worker
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened, we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self._out_for_delivery.copy().values():
await self.local_event_sender.send(event)
async def _clear_ofd(self) -> None:
with self._global_event_receiver as events:
async for event in events:
self._out_for_delivery.pop(event.event.event_id, None)
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
@@ -321,6 +355,23 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
self._out_for_delivery[event.event_id] = fe
async def _emit_existing_download_progress(self) -> None:
try:
while True:

View File

@@ -0,0 +1,98 @@
from typing import Any
import anyio
import pytest
from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import NoopShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
DownloadPending,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.utils.channels import channel
# Use the builtin NoopShardDownloader directly — it already implements the required
# abstract interface, so no additional subclass is needed for this test.
@pytest.mark.anyio
async def test_ack_behaviour():
    """Verify the DownloadCoordinator's ack/resend bookkeeping.

    An event sent through the internal event channel must be parked in
    ``_out_for_delivery`` until a matching GlobalForwarderEvent (the ack)
    arrives, at which point ``_clear_ofd`` removes it.
    """
    # Create channels (typed Any for simplicity; only plumbing is exercised here)
    _, command_receiver = channel[Any]()
    local_sender, _ = channel[Any]()
    global_sender, global_receiver = channel[Any]()
    # Minimal identifiers
    node_id = NodeId()
    session_id = SessionId(master_node_id=node_id, election_clock=0)
    # Create a dummy model card and shard metadata (single layer, zero storage)
    model_id = ModelId("test/model")
    model_card = ModelCard(
        model_id=model_id,
        storage_size=Memory.from_bytes(0),
        n_layers=1,
        hidden_size=1,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )
    shard = PipelineShardMetadata(
        model_card=model_card,
        device_rank=0,
        world_size=1,
        start_layer=0,
        end_layer=1,
        n_layers=1,
    )
    # Instantiate the coordinator with the no-op downloader
    coord = DownloadCoordinator(
        node_id=node_id,
        session_id=session_id,
        shard_downloader=NoopShardDownloader(),
        download_command_receiver=command_receiver,
        local_event_sender=local_sender,
        _global_event_receiver=global_receiver,
    )
    async with anyio.create_task_group() as tg:
        # Start the forwarding and ack-clearing loops
        tg.start_soon(coord._forward_events)  # pyright: ignore[reportPrivateUsage]
        tg.start_soon(coord._clear_ofd)  # pyright: ignore[reportPrivateUsage]
        # Send a pending download progress event via the internal event sender
        pending = DownloadPending(
            node_id=node_id,
            shard_metadata=shard,
            model_directory="/tmp/model",
        )
        await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
        # Allow the forwarder to process the event
        await anyio.sleep(0.1)
        # There should be exactly one entry awaiting ACK
        assert len(coord._out_for_delivery) == 1  # pyright: ignore[reportPrivateUsage]
        # Retrieve the stored LocalForwarderEvent
        stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values()))  # pyright: ignore[reportPrivateUsage]
        # Simulate receiving a global ack for this event.
        # NOTE(review): _clear_ofd keys off event.event.event_id only, so the
        # origin/idx fields here are not checked by the code under test.
        ack = GlobalForwarderEvent(
            origin_idx=0,
            origin=node_id,
            session=session_id,
            event=stored_fe.event,
        )
        await global_sender.send(ack)
        # Give the _clear_ofd task a moment to process the ack
        await anyio.sleep(0.1)
        # The _out_for_delivery map should now be empty
        assert len(coord._out_for_delivery) == 0  # pyright: ignore[reportPrivateUsage]
        # Cancel background tasks
        tg.cancel_scope.cancel()

View File

@@ -15,7 +15,6 @@ from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.event_router import EventRouter
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
@@ -30,7 +29,6 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
event_router: EventRouter
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
@@ -54,12 +52,6 @@ class Node:
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
event_router = EventRouter(
session_id,
command_sender=router.sender(topics.COMMANDS),
external_outbound=router.sender(topics.LOCAL_EVENTS),
external_inbound=router.receiver(topics.GLOBAL_EVENTS),
)
logger.info(f"Starting node {node_id}")
@@ -67,10 +59,13 @@ class Node:
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
offline=args.offline,
# TODO(evan): remove
_global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
)
else:
download_coordinator = None
@@ -78,8 +73,9 @@ class Node:
if args.spawn_api:
api = API(
node_id,
session_id,
port=args.api_port,
event_receiver=event_router.receiver(),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
@@ -90,8 +86,9 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
session_id,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
@@ -102,7 +99,6 @@ class Node:
master = Master(
node_id,
session_id,
event_sender=event_router.sender(),
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
@@ -125,7 +121,6 @@ class Node:
return cls(
router,
event_router,
download_coordinator,
worker,
election,
@@ -141,7 +136,6 @@ class Node:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.event_router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
@@ -189,7 +183,6 @@ class Node:
self.master = Master(
self.node_id,
result.session_id,
event_sender=self.event_router.sender(),
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
@@ -213,24 +206,21 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
self.event_router.shutdown()
self.event_router = EventRouter(
result.session_id,
self.router.sender(topics.COMMANDS),
self.router.receiver(topics.GLOBAL_EVENTS),
self.router.sender(topics.LOCAL_EVENTS),
)
self._tg.start_soon(self.event_router.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
event_sender=self.event_router.sender(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
offline=self.offline,
# TODO(evan): remove
_global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -238,8 +228,11 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
result.session_id,
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
@@ -247,7 +240,7 @@ class Node:
)
self._tg.start_soon(self.worker.run)
if self.api:
self.api.reset(result.won_clock, self.event_router.receiver())
self.api.reset(result.session_id, result.won_clock)
else:
if self.api:
self.api.unpause(result.won_clock)
@@ -259,7 +252,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn")
mp.set_start_method("spawn", force=True)
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
@@ -268,6 +261,13 @@ def main():
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set trust_remote_code override env var for runner subprocesses
if args.trust_remote_code:
os.environ["EXO_TRUST_REMOTE_CODE"] = "1"
logger.warning(
"--trust-remote-code enabled: models may execute arbitrary code during loading"
)
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -292,6 +292,7 @@ class Args(CamelCaseModel):
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool = False
@classmethod
def parse(cls) -> Self:
@@ -343,6 +344,11 @@ class Args(CamelCaseModel):
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Allow models to execute custom code during tokenizer loading (security-sensitive, CLI-only)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -140,10 +140,11 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
GlobalForwarderEvent,
IndexedEvent,
TracesMerged,
)
@@ -171,6 +172,7 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
@@ -194,9 +196,10 @@ class API:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
port: int,
event_receiver: Receiver[IndexedEvent],
global_event_receiver: Receiver[GlobalForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
@@ -207,9 +210,11 @@ class API:
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_receiver = event_receiver
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.last_completed_election: int = 0
self.port = port
@@ -249,18 +254,17 @@ class API:
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup = TaskGroup()
def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):
def reset(self, new_session_id: SessionId, result_clock: int):
logger.info("Resetting API State")
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
self.event_receiver.close()
self.event_receiver = event_receiver
self._tg.start_soon(self._apply_state)
def unpause(self, result_clock: int):
logger.info("Unpausing API")
@@ -1602,7 +1606,7 @@ class API:
finally:
self._event_log.close()
self.command_sender.close()
self.event_receiver.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
@@ -1619,33 +1623,38 @@ class API:
)
async def _apply_state(self):
idx = 0
with self.event_receiver as events:
async for event in events:
self._event_log.append(event.event)
self.state = apply(self.state, event)
idx += 1
event = event.event
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(event.command_id, None)
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
if isinstance(event, ChunkGenerated):
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
def _save_merged_trace(self, event: TracesMerged) -> None:
traces = [

View File

@@ -60,7 +60,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
from exo.utils.task_group import TaskGroup
@@ -72,21 +72,25 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
event_sender: Sender[Event],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id = node_id
self.session_id = session_id
self.state = State()
self._tg: TaskGroup = TaskGroup()
self.node_id = node_id
self.session_id = session_id
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
self.event_sender = event_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
@@ -100,12 +104,15 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -402,6 +409,22 @@ class Master:
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
LocalForwarderEvent(
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly

View File

@@ -17,7 +17,6 @@ from exo.shared.types.commands import (
)
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
GlobalForwarderEvent,
IndexedEvent,
InstanceCreated,
@@ -51,7 +50,6 @@ async def test_master():
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
ev_send, _ev_recv = channel[Event]()
all_events: list[IndexedEvent] = []
@@ -69,7 +67,6 @@ async def test_master():
master = Master(
node_id,
session_id,
event_sender=ev_send,
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,

View File

@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
else:
ip_part = coordinator.split(":")[0]
assert len(ip_part.split(".")) == 4
def _make_task(
    instance_id: InstanceId,
    status: TaskStatus = TaskStatus.Running,
) -> TextGeneration:
    """Build a minimal TextGeneration task bound to *instance_id* in *status*."""
    params = TextGenerationTaskParams(
        model=ModelId("test-model"),
        input=[InputMessage(role="user", content="hello")],
    )
    return TextGeneration(
        task_id=TaskId(),
        command_id=CommandId(),
        instance_id=instance_id,
        task_status=status,
        task_params=params,
    )
def test_get_transition_events_delete_instance_cancels_running_tasks(
    instance: Instance,
):
    """Deleting an instance cancels its Running task before emitting the deletion."""
    deleted_id = InstanceId()
    running_task = _make_task(deleted_id, TaskStatus.Running)
    before: dict[InstanceId, Instance] = {deleted_id: instance}
    after: dict[InstanceId, Instance] = {}

    events = get_transition_events(before, after, {running_task.task_id: running_task})

    # The cancellation must come before the deletion event.
    assert len(events) == 2
    cancel, delete = events
    assert isinstance(cancel, TaskStatusUpdated)
    assert cancel.task_id == running_task.task_id
    assert cancel.task_status == TaskStatus.Cancelled
    assert isinstance(delete, InstanceDeleted)
    assert delete.instance_id == deleted_id
def test_get_transition_events_delete_instance_cancels_pending_tasks(
    instance: Instance,
):
    """Deleting an instance also cancels tasks still in the Pending state."""
    deleted_id = InstanceId()
    pending_task = _make_task(deleted_id, TaskStatus.Pending)
    before: dict[InstanceId, Instance] = {deleted_id: instance}
    after: dict[InstanceId, Instance] = {}

    events = get_transition_events(before, after, {pending_task.task_id: pending_task})

    assert len(events) == 2
    cancel, delete = events
    assert isinstance(cancel, TaskStatusUpdated)
    assert cancel.task_id == pending_task.task_id
    assert cancel.task_status == TaskStatus.Cancelled
    assert isinstance(delete, InstanceDeleted)
def test_get_transition_events_delete_instance_ignores_completed_tasks(
    instance: Instance,
):
    """Tasks already in a terminal state are left alone when their instance dies."""
    gone_id = InstanceId()
    terminal_tasks = [
        _make_task(gone_id, terminal_status)
        for terminal_status in (
            TaskStatus.Complete,
            TaskStatus.Failed,
            TaskStatus.TimedOut,
            TaskStatus.Cancelled,
        )
    ]
    before: dict[InstanceId, Instance] = {gone_id: instance}
    after: dict[InstanceId, Instance] = {}

    events = get_transition_events(
        before, after, {t.task_id: t for t in terminal_tasks}
    )

    # Only the InstanceDeleted event — no cancellations for finished tasks.
    assert len(events) == 1
    assert isinstance(events[0], InstanceDeleted)
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
    instance: Instance,
):
    """Only tasks on the deleted instance are cancelled; survivors are untouched."""
    doomed_id = InstanceId()
    kept_id = InstanceId()
    before: dict[InstanceId, Instance] = {
        doomed_id: instance,
        kept_id: instance,
    }
    # Delete only instance A; instance B remains in the target set.
    after: dict[InstanceId, Instance] = {kept_id: instance}
    doomed_task = _make_task(doomed_id, TaskStatus.Running)
    kept_task = _make_task(kept_id, TaskStatus.Running)
    all_tasks = {doomed_task.task_id: doomed_task, kept_task.task_id: kept_task}

    events = get_transition_events(before, after, all_tasks)

    cancellations = [e for e in events if isinstance(e, TaskStatusUpdated)]
    deletions = [e for e in events if isinstance(e, InstanceDeleted)]
    # Exactly one cancellation, and it targets the doomed instance's task.
    assert len(cancellations) == 1
    assert cancellations[0].task_id == doomed_task.task_id
    assert cancellations[0].task_status == TaskStatus.Cancelled
    assert len(deletions) == 1
    assert deletions[0].instance_id == doomed_id

View File

@@ -1,161 +0,0 @@
from dataclasses import dataclass, field
from random import random
import anyio
from anyio import BrokenResourceError, ClosedResourceError
from anyio.abc import CancelScope
from loguru import logger
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
IndexedEvent,
LocalForwarderEvent,
)
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
@dataclass
class EventRouter:
session_id: SessionId
command_sender: Sender[ForwarderCommand]
external_inbound: Receiver[GlobalForwarderEvent]
external_outbound: Sender[LocalForwarderEvent]
_system_id: SystemId = field(init=False, default_factory=SystemId)
internal_outbound: list[Sender[IndexedEvent]] = field(
init=False, default_factory=list
)
event_buffer: OrderedBuffer[Event] = field(
init=False, default_factory=OrderedBuffer
)
out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(
init=False, default_factory=dict
)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_nack_cancel_scope: CancelScope | None = field(init=False, default=None)
_nack_attempts: int = field(init=False, default=0)
_nack_base_seconds: float = field(init=False, default=0.5)
_nack_cap_seconds: float = field(init=False, default=10.0)
async def run(self):
try:
async with self._tg as tg:
tg.start_soon(self._run_ext_in)
tg.start_soon(self._simple_retry)
finally:
self.external_outbound.close()
for send in self.internal_outbound:
send.close()
# can make this better in future
async def _simple_retry(self):
while True:
await anyio.sleep(1 + random())
# list here is a shallow clone for shared mutation
for e_id, (time, event) in list(self.out_for_delivery.items()):
if anyio.current_time() > time + 5:
self.out_for_delivery[e_id] = (anyio.current_time(), event)
await self.external_outbound.send(event)
def sender(self) -> Sender[Event]:
send, recv = channel[Event]()
if self._tg.is_running():
self._tg.start_soon(self._ingest, SystemId(), recv)
else:
self._tg.queue(self._ingest, SystemId(), recv)
return send
def receiver(self) -> Receiver[IndexedEvent]:
send, recv = channel[IndexedEvent]()
self.internal_outbound.append(send)
return recv
def shutdown(self) -> None:
self._tg.cancel_tasks()
async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):
idx = 0
with recv as events:
async for event in events:
f_ev = LocalForwarderEvent(
origin_idx=idx,
origin=system_id,
session=self.session_id,
event=event,
)
idx += 1
await self.external_outbound.send(f_ev)
self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)
async def _run_ext_in(self):
    """Consume events echoed back from the master and deliver them in order.

    Incoming events are acknowledged (removed from out_for_delivery),
    re-ordered through a local OrderedBuffer, and fanned out to every
    internal subscriber. A gap in the sequence triggers a NACK task
    requesting a replay of the event log from the missing index.
    """
    buf = OrderedBuffer[Event]()
    with self.external_inbound as events:
        async for event in events:
            # Ignore traffic from other sessions or non-master origins.
            if event.session != self.session_id:
                continue
            if event.origin != self.session_id.master_node_id:
                continue
            buf.ingest(event.origin_idx, event.event)
            # Receiving our own event back acts as the delivery ack.
            event_id = event.event.event_id
            if event_id in self.out_for_delivery:
                self.out_for_delivery.pop(event_id)
            drained = buf.drain_indexed()
            if drained:
                # Progress was made: reset backoff and cancel any pending NACK.
                self._nack_attempts = 0
                if self._nack_cancel_scope:
                    self._nack_cancel_scope.cancel()
            if not drained and (
                self._nack_cancel_scope is None
                or self._nack_cancel_scope.cancel_called
            ):
                # Request the next index.
                self._tg.start_soon(self._nack_request, buf.next_idx_to_release)
                continue
            for idx, event in drained:
                to_clear = set[int]()
                for i, sender in enumerate(self.internal_outbound):
                    try:
                        await sender.send(IndexedEvent(idx=idx, event=event))
                    except (ClosedResourceError, BrokenResourceError):
                        # Subscriber went away; drop its channel below.
                        to_clear.add(i)
                # Pop in reverse so earlier indices stay valid while popping.
                for i in sorted(to_clear, reverse=True):
                    self.internal_outbound.pop(i)
async def _nack_request(self, since_idx: int) -> None:
    # We request all events after (and including) the missing index.
    # This function is started whenever we receive an event that is out of sequence.
    # It is cancelled as soon as we receive an event that is in sequence.
    if since_idx < 0:
        logger.warning(f"Negative value encountered for nack request {since_idx=}")
        since_idx = 0
    with CancelScope() as scope:
        self._nack_cancel_scope = scope
        # Exponential backoff, capped, so repeated gaps don't flood the master.
        delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
        delay = min(self._nack_cap_seconds, delay)
        self._nack_attempts += 1
        try:
            await anyio.sleep(delay)
            logger.info(
                f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
            )
            await self.command_sender.send(
                ForwarderCommand(
                    origin=self._system_id,
                    command=RequestEventLog(since_idx=since_idx),
                )
            )
        finally:
            # Only clear if a newer nack task hasn't already replaced this scope.
            if self._nack_cancel_scope is scope:
                self._nack_cancel_scope = None

View File

@@ -90,6 +90,7 @@ class ModelCard(CamelCaseModel):
base_model: str = ""
capabilities: list[str] = []
uses_cfg: bool = False
trust_remote_code: bool = True
@field_validator("tasks", mode="before")
@classmethod
@@ -137,6 +138,7 @@ class ModelCard(CamelCaseModel):
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
trust_remote_code=False,
)
await mc.save_to_custom_dir()
_card_cache[model_id] = mc

View File

@@ -13,5 +13,6 @@ KV_CACHE_BITS: int | None = None
DEFAULT_TOP_LOGPROBS: int = 5
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
# True for built-in models with known model cards; custom models added via API default to False
# and can be overridden with the --trust-remote-code CLI flag.
TRUST_REMOTE_CODE: bool = True

View File

@@ -23,9 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
TRUST_REMOTE_CODE,
)
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
try:
from mlx_lm.tokenizer_utils import load_tokenizer
@@ -293,7 +291,15 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
trust_remote_code = (
shard_metadata.model_card.trust_remote_code
or os.environ.get("EXO_TRUST_REMOTE_CODE") == "1"
)
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code,
)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
@@ -325,7 +331,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
@@ -394,7 +400,7 @@ def load_tokenizer_for_model_id(
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
eos_token_ids=eos_token_ids,
)

View File

@@ -1,8 +1,9 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
import anyio
from anyio import fail_after
from anyio import CancelScope, fail_after
from loguru import logger
from exo.download.download_utils import resolve_model_in_path
@@ -12,12 +13,14 @@ from exo.shared.types.api import ImageEditsTaskParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
LocalForwarderEvent,
@@ -43,6 +46,7 @@ from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
@@ -55,27 +59,38 @@ class Worker:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
event_receiver: Receiver[IndexedEvent],
event_sender: Sender[Event],
global_event_receiver: Receiver[GlobalForwarderEvent],
local_event_sender: Sender[LocalForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.event_receiver = event_receiver
self.event_sender = event_sender
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
@@ -93,12 +108,14 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.event_sender.close()
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
@@ -116,22 +133,47 @@ class Worker:
)
async def _event_applier(self):
with self.event_receiver as events:
async for event in events:
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
if event_id in self.out_for_delivery:
del self.out_for_delivery[event_id]
# 2. for each event, apply it to the state
self.state = apply(self.state, event=event)
event = event.event
indexed_events = self.event_buffer.drain_indexed()
if indexed_events:
self._nack_attempts = 0
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
if not indexed_events and (
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
)
continue
elif indexed_events and self._nack_cancel_scope:
self._nack_cancel_scope.cancel()
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
@@ -283,6 +325,43 @@ class Worker:
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
        # It is cancelled as soon as we receive an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
with CancelScope() as scope:
self._nack_cancel_scope = scope
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
delay = min(self._nack_cap_seconds, delay)
self._nack_attempts += 1
try:
await anyio.sleep(delay)
logger.info(
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
finally:
if self._nack_cancel_scope is scope:
self._nack_cancel_scope = None
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened, we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
@@ -293,6 +372,21 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
async def _poll_connection_updates(self):
while True:
edges = set(

View File

@@ -106,13 +106,18 @@ class RunnerSupervisor:
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._tg.cancel_tasks()
self._ev_recv.close()
self._task_sender.close()
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")

View File

@@ -0,0 +1,125 @@
#!/usr/bin/env bash
# Test that models added via API get trust_remote_code=false
# Run this against a running exo instance.
# Usage: ./test_trust_remote_code_attack.sh [host:port]
# NOTE: -e is deliberately omitted from `set` so that failing curl/grep calls
# fall through to the explicit HTTP-code and file checks below.
set -uo pipefail
HOST="${1:-localhost:52415}"
# Model whose tokenizer would execute attacker code if remote code were trusted.
MODEL_ID="KevTheHermit/security-testing"
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"
echo "=== Test: trust_remote_code attack via API ==="
echo "Target: $HOST"
echo ""
# Clean up RCE proof from previous runs
rm -f /tmp/exo-rce-proof.txt
# Step 0: Clean up any stale card from previous runs
if [ -f "$CARD_FILE" ]; then
  echo "[0] Removing stale card from previous run ..."
  curl -s -X DELETE \
    "http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
  rm -f "$CARD_FILE"
  echo " Done"
  echo ""
fi
# Step 1: Add the malicious model via API
echo "[1] Adding model via POST /models/add ..."
# -w appends the HTTP status on its own line; tail/sed split code from body below.
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
  -H "Content-Type: application/json" \
  -d "{\"model_id\":\"$MODEL_ID\"}")
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
echo " HTTP $HTTP_CODE"
if [ "$HTTP_CODE" -ge 400 ]; then
  echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
  echo " Response: $BODY"
  echo ""
  echo "RESULT: Model was rejected at add time. Attack blocked."
  exit 0
fi
# Step 2: Verify the saved TOML has trust_remote_code = false
# NOTE: the vulnerable branch only reports; the RCE-proof check in 3c decides.
echo ""
echo "[2] Checking saved model card TOML ..."
if [ ! -f "$CARD_FILE" ]; then
  echo " FAIL: Card file not found at $CARD_FILE"
  exit 1
fi
if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
  echo " SAFE: trust_remote_code = false (fix is active)"
else
  echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
fi
echo " Contents:"
cat "$CARD_FILE"
# Step 3: Place the instance
echo ""
echo "[3] Attempting POST /place_instance ..."
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
  -H "Content-Type: application/json" \
  -d "{\"model_id\":\"$MODEL_ID\"}")
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
echo " HTTP $PLACE_CODE"
echo " Response: $PLACE_BODY"
# Step 3b: Send a chat completion to actually trigger tokenizer loading
echo ""
echo "[3b] Sending chat completion to trigger tokenizer load ..."
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
echo " HTTP $CHAT_CODE"
echo " Response: $CHAT_BODY"
echo ""
echo "[3c] Checking for RCE proof ..."
# Give an asynchronously-loaded tokenizer a moment to (not) drop the proof file.
sleep 5
if [ -f /tmp/exo-rce-proof.txt ]; then
  echo " VULNERABLE: Remote code executed!"
  echo " Contents:"
  cat /tmp/exo-rce-proof.txt
else
  echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
fi
# Step 4: Clean up — delete instance and custom model
echo ""
echo "[4] Cleaning up ..."
# Find and delete any instance for this model
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
import sys, json
state = json.load(sys.stdin)
for iid, wrapper in state.get('instances', {}).items():
    for tag, inst in wrapper.items():
        sa = inst.get('shardAssignments', {})
        if sa.get('modelId', '') == '$MODEL_ID':
            print(iid)
            sys.exit(0)
" 2>/dev/null || true)
if [ -n "$INSTANCE_ID" ]; then
  echo " Deleting instance $INSTANCE_ID ..."
  curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
  echo " Done"
else
  echo " No instance found to delete"
fi
echo " Deleting custom model card ..."
curl -s -X DELETE \
  "http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
echo " Done"
echo ""
echo "=== DONE ==="