Compare commits

..

4 Commits

Author SHA1 Message Date
Alex Cheema
daf2f9f48e fix: keep TRUST_REMOTE_CODE=True for built-in models
The constant is the default for built-in models with known model cards,
which are trusted. Custom models added via API already default to
trust_remote_code=False in ModelCard.fetch_from_hf(). The CLI flag
overrides custom models only.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 13:21:56 -08:00
Alex Cheema
5617aed345 feat: add --trust-remote-code CLI flag for custom model tokenizers
Some custom models (e.g. Kimi) require trust_remote_code=True to load
their tokenizers. This adds an opt-in CLI flag that sets an env var
read by runner subprocesses, following the same pattern as --fast-synch.
The flag is intentionally CLI-only (not API-accessible) to prevent
remote code execution attacks via the API.

Also changes the default TRUST_REMOTE_CODE constant from True to False,
making remote code execution fully opt-in.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 13:17:52 -08:00
rltakashige
365dd68d9a Final fixes for release (#1603)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-02-23 21:10:15 +00:00
Alex Cheema
d3d129581e test: verify instance deletion cancels ongoing tasks (#1508)
## Summary
- The cancellation logic for issue #1215 already exists in
`get_transition_events()` (`src/exo/master/placement.py:208-227`) — when
an instance is deleted, `TaskStatusUpdated(Cancelled)` events are
emitted for all Pending/Running tasks on that instance
- Combined with PR #1276's token-boundary cancellation in runners, the
full pipeline works end-to-end
- However, the existing test
`test_get_transition_events_delete_instance` passed `{}` for tasks, so
this path was never exercised
- This PR adds 4 tests covering the cancellation behavior:
  - Running tasks are cancelled on instance deletion
  - Pending tasks are cancelled on instance deletion
  - Completed/Failed/TimedOut/Cancelled tasks are left alone
  - Only tasks matching the deleted instance are cancelled

Closes #1215

## Test plan
- [x] `uv run pytest src/exo/master/tests/test_placement.py -v` — all 15
tests pass
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:12:23 +00:00
6 changed files with 139 additions and 5 deletions

View File

@@ -261,6 +261,13 @@ def main():
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set trust_remote_code override env var for runner subprocesses
if args.trust_remote_code:
os.environ["EXO_TRUST_REMOTE_CODE"] = "1"
logger.warning(
"--trust-remote-code enabled: models may execute arbitrary code during loading"
)
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -285,6 +292,7 @@ class Args(CamelCaseModel):
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool = False
@classmethod
def parse(cls) -> Self:
@@ -336,6 +344,11 @@ class Args(CamelCaseModel):
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Allow models to execute custom code during tokenizer loading (security-sensitive, CLI-only)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
else:
ip_part = coordinator.split(":")[0]
assert len(ip_part.split(".")) == 4
def _make_task(
instance_id: InstanceId,
status: TaskStatus = TaskStatus.Running,
) -> TextGeneration:
return TextGeneration(
task_id=TaskId(),
task_status=status,
instance_id=instance_id,
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("test-model"),
input=[InputMessage(role="user", content="hello")],
),
)
def test_get_transition_events_delete_instance_cancels_running_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
task = _make_task(instance_id, TaskStatus.Running)
tasks = {task.task_id: task}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert cancellation event should come before the deletion event
assert len(events) == 2
assert isinstance(events[0], TaskStatusUpdated)
assert events[0].task_id == task.task_id
assert events[0].task_status == TaskStatus.Cancelled
assert isinstance(events[1], InstanceDeleted)
assert events[1].instance_id == instance_id
def test_get_transition_events_delete_instance_cancels_pending_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
task = _make_task(instance_id, TaskStatus.Pending)
tasks = {task.task_id: task}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert
assert len(events) == 2
assert isinstance(events[0], TaskStatusUpdated)
assert events[0].task_id == task.task_id
assert events[0].task_status == TaskStatus.Cancelled
assert isinstance(events[1], InstanceDeleted)
def test_get_transition_events_delete_instance_ignores_completed_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
tasks = {
t.task_id: t
for t in [
_make_task(instance_id, TaskStatus.Complete),
_make_task(instance_id, TaskStatus.Failed),
_make_task(instance_id, TaskStatus.TimedOut),
_make_task(instance_id, TaskStatus.Cancelled),
]
}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert only the InstanceDeleted event, no cancellations
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
instance: Instance,
):
# arrange
instance_id_a = InstanceId()
instance_id_b = InstanceId()
current_instances: dict[InstanceId, Instance] = {
instance_id_a: instance,
instance_id_b: instance,
}
# only delete instance A, keep instance B
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
task_a = _make_task(instance_id_a, TaskStatus.Running)
task_b = _make_task(instance_id_b, TaskStatus.Running)
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert only task_a should be cancelled
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
assert len(cancel_events) == 1
assert cancel_events[0].task_id == task_a.task_id
assert cancel_events[0].task_status == TaskStatus.Cancelled
assert len(delete_events) == 1
assert delete_events[0].instance_id == instance_id_a

View File

@@ -13,5 +13,6 @@ KV_CACHE_BITS: int | None = None
DEFAULT_TOP_LOGPROBS: int = 5
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
# True for built-in models with known model cards; custom models added via API default to False
# and can be overridden with the --trust-remote-code CLI flag.
TRUST_REMOTE_CODE: bool = True

View File

@@ -291,10 +291,14 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
trust_remote_code = (
shard_metadata.model_card.trust_remote_code
or os.environ.get("EXO_TRUST_REMOTE_CODE") == "1"
)
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
trust_remote_code=trust_remote_code,
)

View File

@@ -8,7 +8,7 @@ from urllib.request import urlopen
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
ts = subprocess.run(
["/Applications/Tailscale.app/Contents/MacOS/Tailscale", "status"], check=True, text=True, capture_output=True
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = next(
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None

View File

@@ -15,7 +15,7 @@ if not (args := sys.argv[1:]):
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["/Applications/Tailscale.app/Contents/MacOS/Tailscale", "status"], check=True, text=True, capture_output=True
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]