Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
dc9000290b fix instance type mismatch in /instance/previews endpoint
Derive instance_meta from the actual instance type returned by
place_instance() instead of using the requested instance_meta from
the loop variable. place_instance() overrides single-node placements
to MlxRing, but the preview response was still reporting the original
requested type (e.g., MlxJaccl), causing a mismatch.

Closes #1426

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 18:04:38 +00:00
10 changed files with 50 additions and 354 deletions

View File

@@ -469,11 +469,11 @@
<td class="px-4 py-3 text-center align-middle">
{#if cell.kind === "completed"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title="Completed ({formatBytes(cell.totalBytes)})"
>
<svg
class="w-7 h-7 text-green-400"
class="w-5 h-5 text-green-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -483,18 +483,18 @@
clip-rule="evenodd"
></path>
</svg>
<span class="text-xs text-exo-light-gray/70"
<span class="text-[10px] text-exo-light-gray/70"
>{formatBytes(cell.totalBytes)}</span
>
<button
type="button"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5"
onclick={() =>
deleteDownload(col.nodeId, row.modelId)}
title="Delete from this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -517,11 +517,11 @@
cell.speed,
)} - ETA {formatEta(cell.etaMs)}"
>
<span class="text-exo-yellow text-sm font-medium"
<span class="text-exo-yellow text-xs font-medium"
>{clampPercent(cell.percentage).toFixed(1)}%</span
>
<div
class="w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden"
class="w-14 h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
@@ -530,25 +530,25 @@
).toFixed(1)}%"
></div>
</div>
<span class="text-[10px] text-exo-light-gray/60"
<span class="text-[9px] text-exo-light-gray/60"
>{formatSpeed(cell.speed)}</span
>
</div>
{:else if cell.kind === "pending"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title={cell.downloaded > 0
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
: "Download pending"}
>
{#if cell.downloaded > 0 && cell.total > 0}
<span class="text-exo-light-gray/70 text-xs"
<span class="text-exo-light-gray/70 text-[10px]"
>{formatBytes(cell.downloaded)} / {formatBytes(
cell.total,
)}</span
>
<div
class="w-full h-1.5 bg-white/10 rounded-full overflow-hidden"
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
>
<div
class="h-full bg-exo-light-gray/40 rounded-full"
@@ -558,55 +558,9 @@
).toFixed(1)}%"
></div>
</div>
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/50 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Resume download on this node"
>
<svg
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/40 text-[10px]"
>paused</span
>
{/if}
{:else if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Start download on this node"
<span class="text-exo-light-gray/40 text-[9px]"
>paused</span
>
<svg
class="w-6 h-6"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/50 text-sm">...</span
>
@@ -614,11 +568,11 @@
</div>
{:else if cell.kind === "failed"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title="Download failed"
>
<svg
class="w-7 h-7 text-red-400"
class="w-5 h-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -631,13 +585,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors cursor-pointer"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Retry download on this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -663,13 +617,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Download to this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"

View File

@@ -823,7 +823,6 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
final_file_exists = await aios.path.exists(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_revision=revision,
@@ -833,9 +832,7 @@ async def download_shard(
total=Memory.from_bytes(file.size or 0),
speed=0,
eta=timedelta(0),
status="complete"
if final_file_exists and downloaded_bytes == file.size
else "not_started",
status="complete" if downloaded_bytes == file.size else "not_started",
start_time=time.time(),
)

View File

@@ -252,7 +252,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
@@ -261,13 +261,6 @@ def main():
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set trust_remote_code override env var for runner subprocesses
if args.trust_remote_code:
os.environ["EXO_TRUST_REMOTE_CODE"] = "1"
logger.warning(
"--trust-remote-code enabled: models may execute arbitrary code during loading"
)
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -292,7 +285,6 @@ class Args(CamelCaseModel):
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool = False
@classmethod
def parse(cls) -> Self:
@@ -344,11 +336,6 @@ class Args(CamelCaseModel):
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Allow models to execute custom code during tokenizer loading (security-sensitive, CLI-only)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -168,7 +168,12 @@ from exo.shared.types.openai_responses import (
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -513,6 +518,14 @@ class API:
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
# Derive instance_meta from the actual instance type, since
# place_instance() may override it (e.g., single-node → MlxRing)
actual_instance_meta = (
InstanceMeta.MlxJaccl
if isinstance(instance, MlxJacclInstance)
else InstanceMeta.MlxRing
)
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
@@ -525,14 +538,14 @@ class API:
if (
model_card.model_id,
sharding,
instance_meta,
actual_instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance_meta=actual_instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -542,7 +555,7 @@ class API:
(
model_card.model_id,
sharding,
instance_meta,
actual_instance_meta,
len(placement_node_ids),
)
)

View File

@@ -14,12 +14,10 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
@@ -458,117 +456,3 @@ def test_tensor_rdma_backend_connectivity_matrix(
else:
ip_part = coordinator.split(":")[0]
assert len(ip_part.split(".")) == 4
def _make_task(
instance_id: InstanceId,
status: TaskStatus = TaskStatus.Running,
) -> TextGeneration:
return TextGeneration(
task_id=TaskId(),
task_status=status,
instance_id=instance_id,
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("test-model"),
input=[InputMessage(role="user", content="hello")],
),
)
def test_get_transition_events_delete_instance_cancels_running_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
task = _make_task(instance_id, TaskStatus.Running)
tasks = {task.task_id: task}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert cancellation event should come before the deletion event
assert len(events) == 2
assert isinstance(events[0], TaskStatusUpdated)
assert events[0].task_id == task.task_id
assert events[0].task_status == TaskStatus.Cancelled
assert isinstance(events[1], InstanceDeleted)
assert events[1].instance_id == instance_id
def test_get_transition_events_delete_instance_cancels_pending_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
task = _make_task(instance_id, TaskStatus.Pending)
tasks = {task.task_id: task}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert
assert len(events) == 2
assert isinstance(events[0], TaskStatusUpdated)
assert events[0].task_id == task.task_id
assert events[0].task_status == TaskStatus.Cancelled
assert isinstance(events[1], InstanceDeleted)
def test_get_transition_events_delete_instance_ignores_completed_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
tasks = {
t.task_id: t
for t in [
_make_task(instance_id, TaskStatus.Complete),
_make_task(instance_id, TaskStatus.Failed),
_make_task(instance_id, TaskStatus.TimedOut),
_make_task(instance_id, TaskStatus.Cancelled),
]
}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert only the InstanceDeleted event, no cancellations
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
instance: Instance,
):
# arrange
instance_id_a = InstanceId()
instance_id_b = InstanceId()
current_instances: dict[InstanceId, Instance] = {
instance_id_a: instance,
instance_id_b: instance,
}
# only delete instance A, keep instance B
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
task_a = _make_task(instance_id_a, TaskStatus.Running)
task_b = _make_task(instance_id_b, TaskStatus.Running)
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert only task_a should be cancelled
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
assert len(cancel_events) == 1
assert cancel_events[0].task_id == task_a.task_id
assert cancel_events[0].task_status == TaskStatus.Cancelled
assert len(delete_events) == 1
assert delete_events[0].instance_id == instance_id_a

View File

@@ -90,7 +90,6 @@ class ModelCard(CamelCaseModel):
base_model: str = ""
capabilities: list[str] = []
uses_cfg: bool = False
trust_remote_code: bool = True
@field_validator("tasks", mode="before")
@classmethod
@@ -138,7 +137,6 @@ class ModelCard(CamelCaseModel):
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
trust_remote_code=False,
)
await mc.save_to_custom_dir()
_card_cache[model_id] = mc

View File

@@ -13,6 +13,5 @@ KV_CACHE_BITS: int | None = None
DEFAULT_TOP_LOGPROBS: int = 5
# True for built-in models with known model cards; custom models added via API default to False
# and can be overridden with the --trust-remote-code CLI flag.
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -23,7 +23,9 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
from exo.worker.engines.mlx.constants import (
TRUST_REMOTE_CODE,
)
try:
from mlx_lm.tokenizer_utils import load_tokenizer
@@ -291,15 +293,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
trust_remote_code = (
shard_metadata.model_card.trust_remote_code
or os.environ.get("EXO_TRUST_REMOTE_CODE") == "1"
)
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code,
)
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
@@ -331,7 +325,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
@@ -400,7 +394,7 @@ def load_tokenizer_for_model_id(
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
)

View File

@@ -106,18 +106,13 @@ class RunnerSupervisor:
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._tg.cancel_tasks()
self._ev_recv.close()
self._task_sender.close()
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")

View File

@@ -1,125 +0,0 @@
#!/usr/bin/env bash
# Test that models added via API get trust_remote_code=false
# Run this against a running exo instance.
# Usage: ./test_trust_remote_code_attack.sh [host:port]
set -uo pipefail
HOST="${1:-localhost:52415}"
MODEL_ID="KevTheHermit/security-testing"
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"
echo "=== Test: trust_remote_code attack via API ==="
echo "Target: $HOST"
echo ""
# Clean up RCE proof from previous runs
rm -f /tmp/exo-rce-proof.txt
# Step 0: Clean up any stale card from previous runs
if [ -f "$CARD_FILE" ]; then
echo "[0] Removing stale card from previous run ..."
curl -s -X DELETE \
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
rm -f "$CARD_FILE"
echo " Done"
echo ""
fi
# Step 1: Add the malicious model via API
echo "[1] Adding model via POST /models/add ..."
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
-H "Content-Type: application/json" \
-d "{\"model_id\":\"$MODEL_ID\"}")
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
echo " HTTP $HTTP_CODE"
if [ "$HTTP_CODE" -ge 400 ]; then
echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
echo " Response: $BODY"
echo ""
echo "RESULT: Model was rejected at add time. Attack blocked."
exit 0
fi
# Step 2: Verify the saved TOML has trust_remote_code = false
echo ""
echo "[2] Checking saved model card TOML ..."
if [ ! -f "$CARD_FILE" ]; then
echo " FAIL: Card file not found at $CARD_FILE"
exit 1
fi
if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
echo " SAFE: trust_remote_code = false (fix is active)"
else
echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
fi
echo " Contents:"
cat "$CARD_FILE"
# Step 3: Place the instance
echo ""
echo "[3] Attempting POST /place_instance ..."
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
-H "Content-Type: application/json" \
-d "{\"model_id\":\"$MODEL_ID\"}")
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
echo " HTTP $PLACE_CODE"
echo " Response: $PLACE_BODY"
# Step 3b: Send a chat completion to actually trigger tokenizer loading
echo ""
echo "[3b] Sending chat completion to trigger tokenizer load ..."
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
echo " HTTP $CHAT_CODE"
echo " Response: $CHAT_BODY"
echo ""
echo "[3c] Checking for RCE proof ..."
sleep 5
if [ -f /tmp/exo-rce-proof.txt ]; then
echo " VULNERABLE: Remote code executed!"
echo " Contents:"
cat /tmp/exo-rce-proof.txt
else
echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
fi
# Step 4: Clean up — delete instance and custom model
echo ""
echo "[4] Cleaning up ..."
# Find and delete any instance for this model
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
import sys, json
state = json.load(sys.stdin)
for iid, wrapper in state.get('instances', {}).items():
for tag, inst in wrapper.items():
sa = inst.get('shardAssignments', {})
if sa.get('modelId', '') == '$MODEL_ID':
print(iid)
sys.exit(0)
" 2>/dev/null || true)
if [ -n "$INSTANCE_ID" ]; then
echo " Deleting instance $INSTANCE_ID ..."
curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
echo " Done"
else
echo " No instance found to delete"
fi
echo " Deleting custom model card ..."
curl -s -X DELETE \
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
echo " Done"
echo ""
echo "=== DONE ==="