Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
dc9000290b fix instance type mismatch in /instance/previews endpoint
Derive instance_meta from the actual instance type returned by
place_instance() instead of using the requested instance_meta from
the loop variable. place_instance() overrides single-node placements
to MlxRing, but the preview response was still reporting the original
requested type (e.g., MlxJaccl), causing a mismatch.

Closes #1426

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 18:04:38 +00:00
10 changed files with 50 additions and 220 deletions

View File

@@ -469,11 +469,11 @@
<td class="px-4 py-3 text-center align-middle">
{#if cell.kind === "completed"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title="Completed ({formatBytes(cell.totalBytes)})"
>
<svg
class="w-7 h-7 text-green-400"
class="w-5 h-5 text-green-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -483,18 +483,18 @@
clip-rule="evenodd"
></path>
</svg>
<span class="text-xs text-exo-light-gray/70"
<span class="text-[10px] text-exo-light-gray/70"
>{formatBytes(cell.totalBytes)}</span
>
<button
type="button"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5"
onclick={() =>
deleteDownload(col.nodeId, row.modelId)}
title="Delete from this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -517,11 +517,11 @@
cell.speed,
)} - ETA {formatEta(cell.etaMs)}"
>
<span class="text-exo-yellow text-sm font-medium"
<span class="text-exo-yellow text-xs font-medium"
>{clampPercent(cell.percentage).toFixed(1)}%</span
>
<div
class="w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden"
class="w-14 h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
@@ -530,25 +530,25 @@
).toFixed(1)}%"
></div>
</div>
<span class="text-[10px] text-exo-light-gray/60"
<span class="text-[9px] text-exo-light-gray/60"
>{formatSpeed(cell.speed)}</span
>
</div>
{:else if cell.kind === "pending"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title={cell.downloaded > 0
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
: "Download pending"}
>
{#if cell.downloaded > 0 && cell.total > 0}
<span class="text-exo-light-gray/70 text-xs"
<span class="text-exo-light-gray/70 text-[10px]"
>{formatBytes(cell.downloaded)} / {formatBytes(
cell.total,
)}</span
>
<div
class="w-full h-1.5 bg-white/10 rounded-full overflow-hidden"
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
>
<div
class="h-full bg-exo-light-gray/40 rounded-full"
@@ -558,55 +558,9 @@
).toFixed(1)}%"
></div>
</div>
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/50 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Resume download on this node"
>
<svg
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/40 text-[10px]"
>paused</span
>
{/if}
{:else if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Start download on this node"
<span class="text-exo-light-gray/40 text-[9px]"
>paused</span
>
<svg
class="w-6 h-6"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/50 text-sm">...</span
>
@@ -614,11 +568,11 @@
</div>
{:else if cell.kind === "failed"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title="Download failed"
>
<svg
class="w-7 h-7 text-red-400"
class="w-5 h-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -631,13 +585,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors cursor-pointer"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Retry download on this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -663,13 +617,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Download to this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"

View File

@@ -823,7 +823,6 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
final_file_exists = await aios.path.exists(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_revision=revision,
@@ -833,9 +832,7 @@ async def download_shard(
total=Memory.from_bytes(file.size or 0),
speed=0,
eta=timedelta(0),
status="complete"
if final_file_exists and downloaded_bytes == file.size
else "not_started",
status="complete" if downloaded_bytes == file.size else "not_started",
start_time=time.time(),
)

View File

@@ -252,7 +252,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")

View File

@@ -168,7 +168,12 @@ from exo.shared.types.openai_responses import (
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -513,6 +518,14 @@ class API:
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
# Derive instance_meta from the actual instance type, since
# place_instance() may override it (e.g., single-node → MlxRing)
actual_instance_meta = (
InstanceMeta.MlxJaccl
if isinstance(instance, MlxJacclInstance)
else InstanceMeta.MlxRing
)
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
@@ -525,14 +538,14 @@ class API:
if (
model_card.model_id,
sharding,
instance_meta,
actual_instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance_meta=actual_instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -542,7 +555,7 @@ class API:
(
model_card.model_id,
sharding,
instance_meta,
actual_instance_meta,
len(placement_node_ids),
)
)

View File

@@ -90,7 +90,6 @@ class ModelCard(CamelCaseModel):
base_model: str = ""
capabilities: list[str] = []
uses_cfg: bool = False
trust_remote_code: bool = True
@field_validator("tasks", mode="before")
@classmethod
@@ -138,7 +137,6 @@ class ModelCard(CamelCaseModel):
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
trust_remote_code=False,
)
await mc.save_to_custom_dir()
_card_cache[model_id] = mc

View File

@@ -23,7 +23,9 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
from exo.worker.engines.mlx.constants import (
TRUST_REMOTE_CODE,
)
try:
from mlx_lm.tokenizer_utils import load_tokenizer
@@ -291,11 +293,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
@@ -327,7 +325,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
@@ -396,7 +394,7 @@ def load_tokenizer_for_model_id(
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
)

View File

@@ -106,18 +106,13 @@ class RunnerSupervisor:
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._tg.cancel_tasks()
self._ev_recv.close()
self._task_sender.close()
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")

View File

@@ -8,7 +8,7 @@ from urllib.request import urlopen
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
ts = subprocess.run(
["/Applications/Tailscale.app/Contents/MacOS/Tailscale", "status"], check=True, text=True, capture_output=True
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = next(
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None

View File

@@ -15,7 +15,7 @@ if not (args := sys.argv[1:]):
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["/Applications/Tailscale.app/Contents/MacOS/Tailscale", "status"], check=True, text=True, capture_output=True
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]

View File

@@ -1,125 +0,0 @@
#!/usr/bin/env bash
# Test that models added via API get trust_remote_code=false
# Run this against a running exo instance.
# Usage: ./test_trust_remote_code_attack.sh [host:port]
set -uo pipefail
HOST="${1:-localhost:52415}"
MODEL_ID="KevTheHermit/security-testing"
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"
echo "=== Test: trust_remote_code attack via API ==="
echo "Target: $HOST"
echo ""
# Clean up RCE proof from previous runs
rm -f /tmp/exo-rce-proof.txt
# Step 0: Clean up any stale card from previous runs
if [ -f "$CARD_FILE" ]; then
echo "[0] Removing stale card from previous run ..."
curl -s -X DELETE \
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
rm -f "$CARD_FILE"
echo " Done"
echo ""
fi
# Step 1: Add the malicious model via API
echo "[1] Adding model via POST /models/add ..."
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
-H "Content-Type: application/json" \
-d "{\"model_id\":\"$MODEL_ID\"}")
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
echo " HTTP $HTTP_CODE"
if [ "$HTTP_CODE" -ge 400 ]; then
echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
echo " Response: $BODY"
echo ""
echo "RESULT: Model was rejected at add time. Attack blocked."
exit 0
fi
# Step 2: Verify the saved TOML has trust_remote_code = false
echo ""
echo "[2] Checking saved model card TOML ..."
if [ ! -f "$CARD_FILE" ]; then
echo " FAIL: Card file not found at $CARD_FILE"
exit 1
fi
if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
echo " SAFE: trust_remote_code = false (fix is active)"
else
echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
fi
echo " Contents:"
cat "$CARD_FILE"
# Step 3: Place the instance
echo ""
echo "[3] Attempting POST /place_instance ..."
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
-H "Content-Type: application/json" \
-d "{\"model_id\":\"$MODEL_ID\"}")
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
echo " HTTP $PLACE_CODE"
echo " Response: $PLACE_BODY"
# Step 3b: Send a chat completion to actually trigger tokenizer loading
echo ""
echo "[3b] Sending chat completion to trigger tokenizer load ..."
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
echo " HTTP $CHAT_CODE"
echo " Response: $CHAT_BODY"
echo ""
echo "[3c] Checking for RCE proof ..."
sleep 5
if [ -f /tmp/exo-rce-proof.txt ]; then
echo " VULNERABLE: Remote code executed!"
echo " Contents:"
cat /tmp/exo-rce-proof.txt
else
echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
fi
# Step 4: Clean up — delete instance and custom model
echo ""
echo "[4] Cleaning up ..."
# Find and delete any instance for this model
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
import sys, json
state = json.load(sys.stdin)
for iid, wrapper in state.get('instances', {}).items():
for tag, inst in wrapper.items():
sa = inst.get('shardAssignments', {})
if sa.get('modelId', '') == '$MODEL_ID':
print(iid)
sys.exit(0)
" 2>/dev/null || true)
if [ -n "$INSTANCE_ID" ]; then
echo " Deleting instance $INSTANCE_ID ..."
curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
echo " Done"
else
echo " No instance found to delete"
fi
echo " Deleting custom model card ..."
curl -s -X DELETE \
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
echo " Done"
echo ""
echo "=== DONE ==="