mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-23 17:58:36 -05:00
Compare commits
4 Commits
leo/test-b
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
daf2f9f48e | ||
|
|
5617aed345 | ||
|
|
365dd68d9a | ||
|
|
d3d129581e |
@@ -261,6 +261,13 @@ def main():
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
# Set trust_remote_code override env var for runner subprocesses
|
||||
if args.trust_remote_code:
|
||||
os.environ["EXO_TRUST_REMOTE_CODE"] = "1"
|
||||
logger.warning(
|
||||
"--trust-remote-code enabled: models may execute arbitrary code during loading"
|
||||
)
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
@@ -285,6 +292,7 @@ class Args(CamelCaseModel):
|
||||
no_downloads: bool = False
|
||||
offline: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
trust_remote_code: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -336,6 +344,11 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Allow models to execute custom code during tokenizer loading (security-sensitive, CLI-only)",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
else:
|
||||
ip_part = coordinator.split(":")[0]
|
||||
assert len(ip_part.split(".")) == 4
|
||||
|
||||
|
||||
def _make_task(
|
||||
instance_id: InstanceId,
|
||||
status: TaskStatus = TaskStatus.Running,
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(),
|
||||
task_status=status,
|
||||
instance_id=instance_id,
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("test-model"),
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_running_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Running)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – cancellation event should come before the deletion event
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
assert events[1].instance_id == instance_id
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_pending_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Pending)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_ignores_completed_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
tasks = {
|
||||
t.task_id: t
|
||||
for t in [
|
||||
_make_task(instance_id, TaskStatus.Complete),
|
||||
_make_task(instance_id, TaskStatus.Failed),
|
||||
_make_task(instance_id, TaskStatus.TimedOut),
|
||||
_make_task(instance_id, TaskStatus.Cancelled),
|
||||
]
|
||||
}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only the InstanceDeleted event, no cancellations
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id_a = InstanceId()
|
||||
instance_id_b = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {
|
||||
instance_id_a: instance,
|
||||
instance_id_b: instance,
|
||||
}
|
||||
# only delete instance A, keep instance B
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
|
||||
|
||||
task_a = _make_task(instance_id_a, TaskStatus.Running)
|
||||
task_b = _make_task(instance_id_b, TaskStatus.Running)
|
||||
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only task_a should be cancelled
|
||||
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
|
||||
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
|
||||
assert len(cancel_events) == 1
|
||||
assert cancel_events[0].task_id == task_a.task_id
|
||||
assert cancel_events[0].task_status == TaskStatus.Cancelled
|
||||
assert len(delete_events) == 1
|
||||
assert delete_events[0].instance_id == instance_id_a
|
||||
|
||||
@@ -90,6 +90,7 @@ class ModelCard(CamelCaseModel):
|
||||
base_model: str = ""
|
||||
capabilities: list[str] = []
|
||||
uses_cfg: bool = False
|
||||
trust_remote_code: bool = True
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -137,6 +138,7 @@ class ModelCard(CamelCaseModel):
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
supports_tensor=config_data.supports_tensor,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
trust_remote_code=False,
|
||||
)
|
||||
await mc.save_to_custom_dir()
|
||||
_card_cache[model_id] = mc
|
||||
|
||||
@@ -13,5 +13,6 @@ KV_CACHE_BITS: int | None = None
|
||||
|
||||
DEFAULT_TOP_LOGPROBS: int = 5
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
# True for built-in models with known model cards; custom models added via API default to False
|
||||
# and can be overridden with the --trust-remote-code CLI flag.
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
@@ -23,9 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
|
||||
|
||||
try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
@@ -293,7 +291,15 @@ def shard_and_load(
|
||||
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
trust_remote_code = (
|
||||
shard_metadata.model_card.trust_remote_code
|
||||
or os.environ.get("EXO_TRUST_REMOTE_CODE") == "1"
|
||||
)
|
||||
return load_tokenizer_for_model_id(
|
||||
shard_metadata.model_card.model_id,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
@@ -325,7 +331,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
@@ -394,7 +400,7 @@ def load_tokenizer_for_model_id(
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
|
||||
125
tmp/test_trust_remote_code_attack.sh
Executable file
125
tmp/test_trust_remote_code_attack.sh
Executable file
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env bash
|
||||
# Test that models added via API get trust_remote_code=false
|
||||
# Run this against a running exo instance.
|
||||
# Usage: ./test_trust_remote_code_attack.sh [host:port]
|
||||
|
||||
set -uo pipefail
|
||||
|
||||
HOST="${1:-localhost:52415}"
|
||||
MODEL_ID="KevTheHermit/security-testing"
|
||||
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
|
||||
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"
|
||||
|
||||
echo "=== Test: trust_remote_code attack via API ==="
|
||||
echo "Target: $HOST"
|
||||
echo ""
|
||||
|
||||
# Clean up RCE proof from previous runs
|
||||
rm -f /tmp/exo-rce-proof.txt
|
||||
|
||||
# Step 0: Clean up any stale card from previous runs
|
||||
if [ -f "$CARD_FILE" ]; then
|
||||
echo "[0] Removing stale card from previous run ..."
|
||||
curl -s -X DELETE \
|
||||
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
|
||||
rm -f "$CARD_FILE"
|
||||
echo " Done"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Step 1: Add the malicious model via API
|
||||
echo "[1] Adding model via POST /models/add ..."
|
||||
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"model_id\":\"$MODEL_ID\"}")
|
||||
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
|
||||
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
|
||||
echo " HTTP $HTTP_CODE"
|
||||
|
||||
if [ "$HTTP_CODE" -ge 400 ]; then
|
||||
echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
|
||||
echo " Response: $BODY"
|
||||
echo ""
|
||||
echo "RESULT: Model was rejected at add time. Attack blocked."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Step 2: Verify the saved TOML has trust_remote_code = false
|
||||
echo ""
|
||||
echo "[2] Checking saved model card TOML ..."
|
||||
if [ ! -f "$CARD_FILE" ]; then
|
||||
echo " FAIL: Card file not found at $CARD_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
|
||||
echo " SAFE: trust_remote_code = false (fix is active)"
|
||||
else
|
||||
echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
|
||||
fi
|
||||
echo " Contents:"
|
||||
cat "$CARD_FILE"
|
||||
|
||||
# Step 3: Place the instance
|
||||
echo ""
|
||||
echo "[3] Attempting POST /place_instance ..."
|
||||
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"model_id\":\"$MODEL_ID\"}")
|
||||
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
|
||||
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
|
||||
echo " HTTP $PLACE_CODE"
|
||||
echo " Response: $PLACE_BODY"
|
||||
|
||||
# Step 3b: Send a chat completion to actually trigger tokenizer loading
|
||||
echo ""
|
||||
echo "[3b] Sending chat completion to trigger tokenizer load ..."
|
||||
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
|
||||
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
|
||||
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
|
||||
echo " HTTP $CHAT_CODE"
|
||||
echo " Response: $CHAT_BODY"
|
||||
echo ""
|
||||
echo "[3c] Checking for RCE proof ..."
|
||||
sleep 5
|
||||
if [ -f /tmp/exo-rce-proof.txt ]; then
|
||||
echo " VULNERABLE: Remote code executed!"
|
||||
echo " Contents:"
|
||||
cat /tmp/exo-rce-proof.txt
|
||||
else
|
||||
echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
|
||||
fi
|
||||
|
||||
# Step 4: Clean up — delete instance and custom model
|
||||
echo ""
|
||||
echo "[4] Cleaning up ..."
|
||||
|
||||
# Find and delete any instance for this model
|
||||
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
|
||||
import sys, json
|
||||
state = json.load(sys.stdin)
|
||||
for iid, wrapper in state.get('instances', {}).items():
|
||||
for tag, inst in wrapper.items():
|
||||
sa = inst.get('shardAssignments', {})
|
||||
if sa.get('modelId', '') == '$MODEL_ID':
|
||||
print(iid)
|
||||
sys.exit(0)
|
||||
" 2>/dev/null || true)
|
||||
|
||||
if [ -n "$INSTANCE_ID" ]; then
|
||||
echo " Deleting instance $INSTANCE_ID ..."
|
||||
curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
|
||||
echo " Done"
|
||||
else
|
||||
echo " No instance found to delete"
|
||||
fi
|
||||
|
||||
echo " Deleting custom model card ..."
|
||||
curl -s -X DELETE \
|
||||
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
|
||||
echo " Done"
|
||||
|
||||
echo ""
|
||||
echo "=== DONE ==="
|
||||
Reference in New Issue
Block a user