mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 02:38:40 -05:00
Compare commits
10 Commits
fix-instan
...
event-rout
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
524c3998e4 | ||
|
|
b65982ddd7 | ||
|
|
2fe689315b | ||
|
|
644c5573ce | ||
|
|
12c3015f52 | ||
|
|
365dd68d9a | ||
|
|
d3d129581e | ||
|
|
c90a0cec78 | ||
|
|
e8c1337168 | ||
|
|
7024ddcf3e |
@@ -75,7 +75,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
model_id,
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken"],
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@
|
||||
{#if showHome}
|
||||
<button
|
||||
onclick={handleHome}
|
||||
class="text-sm text-exo-light-gray hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
|
||||
class="text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
|
||||
title="Back to topology view"
|
||||
>
|
||||
<svg
|
||||
@@ -116,7 +116,7 @@
|
||||
{/if}
|
||||
<a
|
||||
href="/#/downloads"
|
||||
class="text-sm text-exo-light-gray hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
|
||||
class="text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
|
||||
title="View downloads overview"
|
||||
>
|
||||
{#if downloadProgress}
|
||||
|
||||
@@ -412,7 +412,7 @@
|
||||
<div>{col.label}</div>
|
||||
{#if col.diskAvailable != null}
|
||||
<div
|
||||
class="text-[9px] text-exo-light-gray/60 normal-case tracking-normal mt-0.5"
|
||||
class="text-[9px] text-white/70 normal-case tracking-normal mt-0.5"
|
||||
>
|
||||
{formatBytes(col.diskAvailable)} free
|
||||
</div>
|
||||
@@ -436,7 +436,7 @@
|
||||
</div>
|
||||
{#if row.prettyName}
|
||||
<div
|
||||
class="text-[10px] text-exo-light-gray/60"
|
||||
class="text-[10px] text-white/60"
|
||||
title={row.modelId}
|
||||
>
|
||||
{row.modelId}
|
||||
@@ -450,7 +450,7 @@
|
||||
title="View model details"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 text-white/30 hover:text-white/60"
|
||||
class="w-4 h-4 text-white/60 hover:text-white/80"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -469,11 +469,11 @@
|
||||
<td class="px-4 py-3 text-center align-middle">
|
||||
{#if cell.kind === "completed"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
class="flex flex-col items-center gap-1"
|
||||
title="Completed ({formatBytes(cell.totalBytes)})"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-green-400"
|
||||
class="w-7 h-7 text-green-400"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -483,18 +483,18 @@
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
<span class="text-[10px] text-exo-light-gray/70"
|
||||
<span class="text-xs text-white/70"
|
||||
>{formatBytes(cell.totalBytes)}</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5"
|
||||
class="text-white/50 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
|
||||
onclick={() =>
|
||||
deleteDownload(col.nodeId, row.modelId)}
|
||||
title="Delete from this node"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
@@ -517,11 +517,11 @@
|
||||
cell.speed,
|
||||
)} - ETA {formatEta(cell.etaMs)}"
|
||||
>
|
||||
<span class="text-exo-yellow text-xs font-medium"
|
||||
<span class="text-exo-yellow text-sm font-medium"
|
||||
>{clampPercent(cell.percentage).toFixed(1)}%</span
|
||||
>
|
||||
<div
|
||||
class="w-14 h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
class="w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
@@ -530,25 +530,25 @@
|
||||
).toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<span class="text-[9px] text-exo-light-gray/60"
|
||||
<span class="text-[10px] text-white/70"
|
||||
>{formatSpeed(cell.speed)}</span
|
||||
>
|
||||
</div>
|
||||
{:else if cell.kind === "pending"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
class="flex flex-col items-center gap-1"
|
||||
title={cell.downloaded > 0
|
||||
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
|
||||
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`
|
||||
: "Download pending"}
|
||||
>
|
||||
{#if cell.downloaded > 0 && cell.total > 0}
|
||||
<span class="text-exo-light-gray/70 text-[10px]"
|
||||
<span class="text-white/70 text-xs"
|
||||
>{formatBytes(cell.downloaded)} / {formatBytes(
|
||||
cell.total,
|
||||
)}</span
|
||||
>
|
||||
<div
|
||||
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
|
||||
class="w-full h-1.5 bg-white/10 rounded-full overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-exo-light-gray/40 rounded-full"
|
||||
@@ -558,21 +558,65 @@
|
||||
).toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<span class="text-exo-light-gray/40 text-[9px]"
|
||||
>paused</span
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Resume download on this node"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-white/50 text-[10px]">paused</span
|
||||
>
|
||||
{/if}
|
||||
{:else if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Start download on this node"
|
||||
>
|
||||
<svg
|
||||
class="w-6 h-6"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-exo-light-gray/50 text-sm">...</span
|
||||
>
|
||||
<span class="text-white/40 text-sm">...</span>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if cell.kind === "failed"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
class="flex flex-col items-center gap-1"
|
||||
title="Download failed"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400"
|
||||
class="w-7 h-7 text-red-400"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -585,13 +629,13 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Retry download on this node"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
@@ -617,13 +661,13 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Download to this node"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
@@ -22,13 +21,9 @@ from exo.shared.types.commands import (
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
# TODO(evan): just for acks, should delete this ASAP
|
||||
GlobalForwarderEvent,
|
||||
LocalForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
@@ -39,40 +34,28 @@ from exo.shared.types.worker.downloads import (
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadCoordinator:
|
||||
node_id: NodeId
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[LocalForwarderEvent]
|
||||
|
||||
# ack stuff
|
||||
_global_event_receiver: Receiver[GlobalForwarderEvent]
|
||||
_out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
|
||||
|
||||
event_sender: Sender[Event]
|
||||
offline: bool = False
|
||||
|
||||
_system_id: SystemId = field(default_factory=SystemId)
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
|
||||
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
# Per-model throttle for download progress events
|
||||
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
if self.offline:
|
||||
self.shard_downloader.set_internet_connection(False)
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
@@ -128,10 +111,7 @@ class DownloadCoordinator:
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._clear_ofd)
|
||||
if not self.offline:
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
finally:
|
||||
@@ -169,20 +149,6 @@ class DownloadCoordinator:
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
# directly copied from worker
|
||||
async def _resend_out_for_delivery(self) -> None:
|
||||
# This can also be massively tightened, we should check events are at least a certain age before resending.
|
||||
# Exponential backoff would also certainly help here.
|
||||
while True:
|
||||
await anyio.sleep(1 + random())
|
||||
for event in self._out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
async def _clear_ofd(self) -> None:
|
||||
with self._global_event_receiver as events:
|
||||
async for event in events:
|
||||
self._out_for_delivery.pop(event.event.event_id, None)
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.download_command_receiver as commands:
|
||||
async for cmd in commands:
|
||||
@@ -355,23 +321,6 @@ class DownloadCoordinator:
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
idx = 0
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
fe = LocalForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self._system_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
await self.local_event_sender.send(fe)
|
||||
self._out_for_delivery[event.event_id] = fe
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -823,6 +823,7 @@ async def download_shard(
|
||||
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
|
||||
final_file_exists = await aios.path.exists(target_dir / file.path)
|
||||
file_progress[file.path] = RepoFileDownloadProgress(
|
||||
repo_id=shard.model_card.model_id,
|
||||
repo_revision=revision,
|
||||
@@ -832,7 +833,9 @@ async def download_shard(
|
||||
total=Memory.from_bytes(file.size or 0),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="complete" if downloaded_bytes == file.size else "not_started",
|
||||
status="complete"
|
||||
if final_file_exists and downloaded_bytes == file.size
|
||||
else "not_started",
|
||||
start_time=time.time(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.shard_downloader import NoopShardDownloader
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
GlobalForwarderEvent,
|
||||
LocalForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadPending,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.utils.channels import channel
|
||||
|
||||
# Use the built‑in NoopShardDownloader directly – it already implements the required abstract interface.
|
||||
# No additional subclass is needed for this test.
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ack_behaviour():
|
||||
# Create channels (type Any for simplicity)
|
||||
_, command_receiver = channel[Any]()
|
||||
local_sender, _ = channel[Any]()
|
||||
global_sender, global_receiver = channel[Any]()
|
||||
|
||||
# Minimal identifiers
|
||||
node_id = NodeId()
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
# Create a dummy model card and shard metadata
|
||||
model_id = ModelId("test/model")
|
||||
model_card = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
hidden_size=1,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
shard = PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
)
|
||||
|
||||
# Instantiate the coordinator with the dummy downloader
|
||||
coord = DownloadCoordinator(
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
shard_downloader=NoopShardDownloader(),
|
||||
download_command_receiver=command_receiver,
|
||||
local_event_sender=local_sender,
|
||||
_global_event_receiver=global_receiver,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start the forwarding and ack‑clearing loops
|
||||
tg.start_soon(coord._forward_events) # pyright: ignore[reportPrivateUsage]
|
||||
tg.start_soon(coord._clear_ofd) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Send a pending download progress event via the internal event sender
|
||||
pending = DownloadPending(
|
||||
node_id=node_id,
|
||||
shard_metadata=shard,
|
||||
model_directory="/tmp/model",
|
||||
)
|
||||
await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
|
||||
# Allow the forwarder to process the event
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# There should be exactly one entry awaiting ACK
|
||||
assert len(coord._out_for_delivery) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
# Retrieve the stored LocalForwarderEvent
|
||||
stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values())) # pyright: ignore[reportPrivateUsage]
|
||||
# Simulate receiving a global ack for this event
|
||||
ack = GlobalForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=node_id,
|
||||
session=session_id,
|
||||
event=stored_fe.event,
|
||||
)
|
||||
await global_sender.send(ack)
|
||||
# Give the clear‑ofd task a moment to process the ack
|
||||
await anyio.sleep(0.1)
|
||||
# The out‑for‑delivery map should now be empty
|
||||
assert len(coord._out_for_delivery) == 0 # pyright: ignore[reportPrivateUsage]
|
||||
# Cancel background tasks
|
||||
tg.cancel_scope.cancel()
|
||||
@@ -15,6 +15,7 @@ from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.event_router import EventRouter
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
from exo.shared.constants import EXO_LOG
|
||||
from exo.shared.election import Election, ElectionResult
|
||||
@@ -29,6 +30,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
event_router: EventRouter
|
||||
download_coordinator: DownloadCoordinator | None
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
@@ -52,6 +54,12 @@ class Node:
|
||||
await router.register_topic(topics.ELECTION_MESSAGES)
|
||||
await router.register_topic(topics.CONNECTION_MESSAGES)
|
||||
await router.register_topic(topics.DOWNLOAD_COMMANDS)
|
||||
event_router = EventRouter(
|
||||
session_id,
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
external_outbound=router.sender(topics.LOCAL_EVENTS),
|
||||
external_inbound=router.receiver(topics.GLOBAL_EVENTS),
|
||||
)
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
@@ -59,13 +67,10 @@ class Node:
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
event_sender=event_router.sender(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
offline=args.offline,
|
||||
# TODO(evan): remove
|
||||
_global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -73,9 +78,8 @@ class Node:
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
session_id,
|
||||
port=args.api_port,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
event_receiver=event_router.receiver(),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
@@ -86,9 +90,8 @@ class Node:
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_receiver=event_router.receiver(),
|
||||
event_sender=event_router.sender(),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
@@ -99,6 +102,7 @@ class Node:
|
||||
master = Master(
|
||||
node_id,
|
||||
session_id,
|
||||
event_sender=event_router.sender(),
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
@@ -121,6 +125,7 @@ class Node:
|
||||
|
||||
return cls(
|
||||
router,
|
||||
event_router,
|
||||
download_coordinator,
|
||||
worker,
|
||||
election,
|
||||
@@ -136,6 +141,7 @@ class Node:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.event_router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
tg.start_soon(self.download_coordinator.run)
|
||||
@@ -183,6 +189,7 @@ class Node:
|
||||
self.master = Master(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
event_sender=self.event_router.sender(),
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
@@ -206,21 +213,24 @@ class Node:
|
||||
)
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
self.event_router.shutdown()
|
||||
self.event_router = EventRouter(
|
||||
result.session_id,
|
||||
self.router.sender(topics.COMMANDS),
|
||||
self.router.receiver(topics.GLOBAL_EVENTS),
|
||||
self.router.sender(topics.LOCAL_EVENTS),
|
||||
)
|
||||
self._tg.start_soon(self.event_router.run)
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
event_sender=self.event_router.sender(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
offline=self.offline,
|
||||
# TODO(evan): remove
|
||||
_global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
@@ -228,11 +238,8 @@ class Node:
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_receiver=self.event_router.receiver(),
|
||||
event_sender=self.event_router.sender(),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
@@ -240,7 +247,7 @@ class Node:
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
self.api.reset(result.session_id, result.won_clock)
|
||||
self.api.reset(result.won_clock, self.event_router.receiver())
|
||||
else:
|
||||
if self.api:
|
||||
self.api.unpause(result.won_clock)
|
||||
@@ -252,7 +259,7 @@ def main():
|
||||
target = min(max(soft, 65535), hard)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
mp.set_start_method("spawn", force=True)
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
|
||||
@@ -140,11 +140,10 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
TracesMerged,
|
||||
)
|
||||
@@ -168,16 +167,10 @@ from exo.shared.types.openai_responses import (
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxJacclInstance,
|
||||
)
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
|
||||
@@ -201,10 +194,9 @@ class API:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
*,
|
||||
port: int,
|
||||
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||
event_receiver: Receiver[IndexedEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
@@ -215,11 +207,9 @@ class API:
|
||||
self._system_id = SystemId()
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.event_receiver = event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
self.last_completed_election: int = 0
|
||||
self.port = port
|
||||
|
||||
@@ -259,17 +249,18 @@ class API:
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):
|
||||
logger.info("Resetting API State")
|
||||
self._event_log.close()
|
||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||
self.state = State()
|
||||
self._system_id = SystemId()
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._text_generation_queues = {}
|
||||
self._image_generation_queues = {}
|
||||
self.unpause(result_clock)
|
||||
self.event_receiver.close()
|
||||
self.event_receiver = event_receiver
|
||||
self._tg.start_soon(self._apply_state)
|
||||
|
||||
def unpause(self, result_clock: int):
|
||||
logger.info("Unpausing API")
|
||||
@@ -518,14 +509,6 @@ class API:
|
||||
shard_assignments = instance.shard_assignments
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
|
||||
# Derive instance_meta from the actual instance type, since
|
||||
# place_instance() may override it (e.g., single-node → MlxRing)
|
||||
actual_instance_meta = (
|
||||
InstanceMeta.MlxJaccl
|
||||
if isinstance(instance, MlxJacclInstance)
|
||||
else InstanceMeta.MlxRing
|
||||
)
|
||||
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
@@ -538,14 +521,14 @@ class API:
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
actual_instance_meta,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=actual_instance_meta,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
@@ -555,7 +538,7 @@ class API:
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
actual_instance_meta,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
@@ -1619,7 +1602,7 @@ class API:
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
self.event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
@@ -1636,38 +1619,31 @@ class API:
|
||||
)
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
continue
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
with self.event_receiver as events:
|
||||
async for i_event in events:
|
||||
self._event_log.append(i_event.event)
|
||||
self.state = apply(self.state, i_event)
|
||||
event = i_event.event
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if queue := self._image_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert not isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if queue := self._image_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(event.command_id, None)
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert not isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
|
||||
def _save_merged_trace(self, event: TracesMerged) -> None:
|
||||
traces = [
|
||||
|
||||
@@ -60,7 +60,7 @@ from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
@@ -72,25 +72,21 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
event_sender: Sender[Event],
|
||||
local_event_receiver: Receiver[LocalForwarderEvent],
|
||||
global_event_sender: Sender[GlobalForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
self.node_id = node_id
|
||||
self.session_id = session_id
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
self.command_task_mapping: dict[CommandId, TaskId] = {}
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self.event_sender = event_sender
|
||||
self._system_id = SystemId()
|
||||
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
@@ -104,15 +100,12 @@ class Master:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -409,22 +402,6 @@ class Master:
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
local_index = 0
|
||||
with self._loopback_event_receiver as events:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
LocalForwarderEvent(
|
||||
origin=self._system_id,
|
||||
origin_idx=local_index,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
# Convenience method since this line is ugly
|
||||
|
||||
@@ -17,6 +17,7 @@ from exo.shared.types.commands import (
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
@@ -50,6 +51,22 @@ async def test_master():
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
ev_send, ev_recv = channel[Event]()
|
||||
|
||||
async def mock_event_router():
|
||||
idx = 0
|
||||
sid = SystemId()
|
||||
with ev_recv as master_events:
|
||||
async for event in master_events:
|
||||
await local_event_sender.send(
|
||||
LocalForwarderEvent(
|
||||
origin=sid,
|
||||
origin_idx=idx,
|
||||
session=session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +84,7 @@ async def test_master():
|
||||
master = Master(
|
||||
node_id,
|
||||
session_id,
|
||||
event_sender=ev_send,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
@@ -75,6 +93,7 @@ async def test_master():
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
tg.start_soon(mock_event_router)
|
||||
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
@@ -197,4 +216,5 @@ async def test_master():
|
||||
input=[InputMessage(role="user", content="Hello, how are you?")],
|
||||
)
|
||||
|
||||
ev_send.close()
|
||||
await master.shutdown()
|
||||
|
||||
@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
else:
|
||||
ip_part = coordinator.split(":")[0]
|
||||
assert len(ip_part.split(".")) == 4
|
||||
|
||||
|
||||
def _make_task(
|
||||
instance_id: InstanceId,
|
||||
status: TaskStatus = TaskStatus.Running,
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(),
|
||||
task_status=status,
|
||||
instance_id=instance_id,
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("test-model"),
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_running_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Running)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – cancellation event should come before the deletion event
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
assert events[1].instance_id == instance_id
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_pending_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Pending)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_ignores_completed_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
tasks = {
|
||||
t.task_id: t
|
||||
for t in [
|
||||
_make_task(instance_id, TaskStatus.Complete),
|
||||
_make_task(instance_id, TaskStatus.Failed),
|
||||
_make_task(instance_id, TaskStatus.TimedOut),
|
||||
_make_task(instance_id, TaskStatus.Cancelled),
|
||||
]
|
||||
}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only the InstanceDeleted event, no cancellations
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id_a = InstanceId()
|
||||
instance_id_b = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {
|
||||
instance_id_a: instance,
|
||||
instance_id_b: instance,
|
||||
}
|
||||
# only delete instance A, keep instance B
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
|
||||
|
||||
task_a = _make_task(instance_id_a, TaskStatus.Running)
|
||||
task_b = _make_task(instance_id_b, TaskStatus.Running)
|
||||
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only task_a should be cancelled
|
||||
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
|
||||
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
|
||||
assert len(cancel_events) == 1
|
||||
assert cancel_events[0].task_id == task_a.task_id
|
||||
assert cancel_events[0].task_status == TaskStatus.Cancelled
|
||||
assert len(delete_events) == 1
|
||||
assert delete_events[0].instance_id == instance_id_a
|
||||
|
||||
161
src/exo/routing/event_router.py
Normal file
161
src/exo/routing/event_router.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from dataclasses import dataclass, field
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, ClosedResourceError
|
||||
from anyio.abc import CancelScope
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
LocalForwarderEvent,
|
||||
)
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventRouter:
|
||||
session_id: SessionId
|
||||
command_sender: Sender[ForwarderCommand]
|
||||
external_inbound: Receiver[GlobalForwarderEvent]
|
||||
external_outbound: Sender[LocalForwarderEvent]
|
||||
_system_id: SystemId = field(init=False, default_factory=SystemId)
|
||||
internal_outbound: list[Sender[IndexedEvent]] = field(
|
||||
init=False, default_factory=list
|
||||
)
|
||||
event_buffer: OrderedBuffer[Event] = field(
|
||||
init=False, default_factory=OrderedBuffer
|
||||
)
|
||||
out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(
|
||||
init=False, default_factory=dict
|
||||
)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
_nack_cancel_scope: CancelScope | None = field(init=False, default=None)
|
||||
_nack_attempts: int = field(init=False, default=0)
|
||||
_nack_base_seconds: float = field(init=False, default=0.5)
|
||||
_nack_cap_seconds: float = field(init=False, default=10.0)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._run_ext_in)
|
||||
tg.start_soon(self._simple_retry)
|
||||
finally:
|
||||
self.external_outbound.close()
|
||||
for send in self.internal_outbound:
|
||||
send.close()
|
||||
|
||||
# can make this better in future
|
||||
async def _simple_retry(self):
|
||||
while True:
|
||||
await anyio.sleep(1 + random())
|
||||
# list here is a shallow clone for shared mutation
|
||||
for e_id, (time, event) in list(self.out_for_delivery.items()):
|
||||
if anyio.current_time() > time + 5:
|
||||
self.out_for_delivery[e_id] = (anyio.current_time(), event)
|
||||
await self.external_outbound.send(event)
|
||||
|
||||
def sender(self) -> Sender[Event]:
|
||||
send, recv = channel[Event]()
|
||||
if self._tg.is_running():
|
||||
self._tg.start_soon(self._ingest, SystemId(), recv)
|
||||
else:
|
||||
self._tg.queue(self._ingest, SystemId(), recv)
|
||||
return send
|
||||
|
||||
def receiver(self) -> Receiver[IndexedEvent]:
|
||||
send, recv = channel[IndexedEvent]()
|
||||
self.internal_outbound.append(send)
|
||||
return recv
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):
|
||||
idx = 0
|
||||
with recv as events:
|
||||
async for event in events:
|
||||
f_ev = LocalForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=system_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
await self.external_outbound.send(f_ev)
|
||||
self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)
|
||||
|
||||
async def _run_ext_in(self):
|
||||
buf = OrderedBuffer[Event]()
|
||||
with self.external_inbound as events:
|
||||
async for event in events:
|
||||
if event.session != self.session_id:
|
||||
continue
|
||||
if event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
|
||||
buf.ingest(event.origin_idx, event.event)
|
||||
event_id = event.event.event_id
|
||||
if event_id in self.out_for_delivery:
|
||||
self.out_for_delivery.pop(event_id)
|
||||
|
||||
drained = buf.drain_indexed()
|
||||
if drained:
|
||||
self._nack_attempts = 0
|
||||
if self._nack_cancel_scope:
|
||||
self._nack_cancel_scope.cancel()
|
||||
|
||||
if not drained and (
|
||||
self._nack_cancel_scope is None
|
||||
or self._nack_cancel_scope.cancel_called
|
||||
):
|
||||
# Request the next index.
|
||||
self._tg.start_soon(self._nack_request, buf.next_idx_to_release)
|
||||
continue
|
||||
|
||||
for idx, event in drained:
|
||||
to_clear = set[int]()
|
||||
for i, sender in enumerate(self.internal_outbound):
|
||||
try:
|
||||
await sender.send(IndexedEvent(idx=idx, event=event))
|
||||
except (ClosedResourceError, BrokenResourceError):
|
||||
to_clear.add(i)
|
||||
for i in sorted(to_clear, reverse=True):
|
||||
self.internal_outbound.pop(i)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
# This function is started whenever we receive an event that is out of sequence.
|
||||
# It is cancelled as soon as we receiver an event that is in sequence.
|
||||
|
||||
if since_idx < 0:
|
||||
logger.warning(f"Negative value encountered for nack request {since_idx=}")
|
||||
since_idx = 0
|
||||
|
||||
with CancelScope() as scope:
|
||||
self._nack_cancel_scope = scope
|
||||
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
|
||||
delay = min(self._nack_cap_seconds, delay)
|
||||
self._nack_attempts += 1
|
||||
try:
|
||||
await anyio.sleep(delay)
|
||||
logger.info(
|
||||
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
|
||||
)
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self._system_id,
|
||||
command=RequestEventLog(since_idx=since_idx),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if self._nack_cancel_scope is scope:
|
||||
self._nack_cancel_scope = None
|
||||
@@ -90,6 +90,7 @@ class ModelCard(CamelCaseModel):
|
||||
base_model: str = ""
|
||||
capabilities: list[str] = []
|
||||
uses_cfg: bool = False
|
||||
trust_remote_code: bool = True
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -137,6 +138,7 @@ class ModelCard(CamelCaseModel):
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
supports_tensor=config_data.supports_tensor,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
trust_remote_code=False,
|
||||
)
|
||||
await mc.save_to_custom_dir()
|
||||
_card_cache[model_id] = mc
|
||||
|
||||
@@ -852,6 +852,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if hasattr(layer, "linear_attn"):
|
||||
|
||||
@@ -23,9 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
|
||||
|
||||
try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
@@ -293,7 +291,11 @@ def shard_and_load(
|
||||
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
return load_tokenizer_for_model_id(
|
||||
shard_metadata.model_card.model_id,
|
||||
model_path,
|
||||
trust_remote_code=shard_metadata.model_card.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
@@ -325,7 +327,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
@@ -394,7 +396,7 @@ def load_tokenizer_for_model_id(
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, fail_after
|
||||
from anyio import fail_after
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import resolve_model_in_path
|
||||
@@ -13,17 +12,13 @@ from exo.shared.types.api import ImageEditsTaskParams
|
||||
from exo.shared.types.commands import (
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
RequestEventLog,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, NodeId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
LocalForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
@@ -46,7 +41,6 @@ from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
@@ -59,38 +53,26 @@ class Worker:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
*,
|
||||
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||
local_event_sender: Sender[LocalForwarderEvent],
|
||||
event_receiver: Receiver[IndexedEvent],
|
||||
event_sender: Sender[Event],
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
self.event_receiver = event_receiver
|
||||
self.event_sender = event_sender
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
|
||||
self._nack_cancel_scope: CancelScope | None = None
|
||||
self._nack_attempts: int = 0
|
||||
self._nack_base_seconds: float = 0.5
|
||||
self._nack_cap_seconds: float = 10.0
|
||||
|
||||
self._system_id = SystemId()
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
@@ -108,14 +90,12 @@ class Worker:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
@@ -133,47 +113,22 @@ class Worker:
|
||||
)
|
||||
|
||||
async def _event_applier(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
continue
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
event_id = f_event.event.event_id
|
||||
if event_id in self.out_for_delivery:
|
||||
del self.out_for_delivery[event_id]
|
||||
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
# 2. for each event, apply it to the state
|
||||
indexed_events = self.event_buffer.drain_indexed()
|
||||
if indexed_events:
|
||||
self._nack_attempts = 0
|
||||
self.state = apply(self.state, event=event)
|
||||
event = event.event
|
||||
|
||||
if not indexed_events and (
|
||||
self._nack_cancel_scope is None
|
||||
or self._nack_cancel_scope.cancel_called
|
||||
):
|
||||
# Request the next index.
|
||||
self._tg.start_soon(
|
||||
self._nack_request, self.state.last_event_applied_idx + 1
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
if cmd_id not in self.input_chunk_buffer:
|
||||
self.input_chunk_buffer[cmd_id] = {}
|
||||
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
|
||||
|
||||
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
|
||||
event.chunk.data
|
||||
)
|
||||
continue
|
||||
elif indexed_events and self._nack_cancel_scope:
|
||||
self._nack_cancel_scope.cancel()
|
||||
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
if cmd_id not in self.input_chunk_buffer:
|
||||
self.input_chunk_buffer[cmd_id] = {}
|
||||
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
|
||||
|
||||
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
|
||||
event.chunk.data
|
||||
)
|
||||
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
@@ -325,43 +280,6 @@ class Worker:
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
# This function is started whenever we receive an event that is out of sequence.
|
||||
# It is cancelled as soon as we receiver an event that is in sequence.
|
||||
|
||||
if since_idx < 0:
|
||||
logger.warning(f"Negative value encountered for nack request {since_idx=}")
|
||||
since_idx = 0
|
||||
|
||||
with CancelScope() as scope:
|
||||
self._nack_cancel_scope = scope
|
||||
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
|
||||
delay = min(self._nack_cap_seconds, delay)
|
||||
self._nack_attempts += 1
|
||||
try:
|
||||
await anyio.sleep(delay)
|
||||
logger.info(
|
||||
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
|
||||
)
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self._system_id,
|
||||
command=RequestEventLog(since_idx=since_idx),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if self._nack_cancel_scope is scope:
|
||||
self._nack_cancel_scope = None
|
||||
|
||||
async def _resend_out_for_delivery(self) -> None:
|
||||
# This can also be massively tightened, we should check events are at least a certain age before resending.
|
||||
# Exponential backoff would also certainly help here.
|
||||
while True:
|
||||
await anyio.sleep(1 + random())
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
@@ -372,21 +290,6 @@ class Worker:
|
||||
self._tg.start_soon(runner.run)
|
||||
return runner
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
idx = 0
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
fe = LocalForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self._system_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
async def _poll_connection_updates(self):
|
||||
while True:
|
||||
edges = set(
|
||||
|
||||
@@ -106,13 +106,18 @@ class RunnerSupervisor:
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._tg.cancel_tasks()
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
if not self._cancel_watch_runner.cancel_called:
|
||||
self._cancel_watch_runner.cancel()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._ev_recv.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._task_sender.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._event_sender.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(5)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
|
||||
125
tmp/test_trust_remote_code_attack.sh
Executable file
125
tmp/test_trust_remote_code_attack.sh
Executable file
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env bash
|
||||
# Test that models added via API get trust_remote_code=false
|
||||
# Run this against a running exo instance.
|
||||
# Usage: ./test_trust_remote_code_attack.sh [host:port]
|
||||
|
||||
set -uo pipefail
|
||||
|
||||
HOST="${1:-localhost:52415}"
|
||||
MODEL_ID="KevTheHermit/security-testing"
|
||||
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
|
||||
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"
|
||||
|
||||
echo "=== Test: trust_remote_code attack via API ==="
|
||||
echo "Target: $HOST"
|
||||
echo ""
|
||||
|
||||
# Clean up RCE proof from previous runs
|
||||
rm -f /tmp/exo-rce-proof.txt
|
||||
|
||||
# Step 0: Clean up any stale card from previous runs
|
||||
if [ -f "$CARD_FILE" ]; then
|
||||
echo "[0] Removing stale card from previous run ..."
|
||||
curl -s -X DELETE \
|
||||
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
|
||||
rm -f "$CARD_FILE"
|
||||
echo " Done"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Step 1: Add the malicious model via API
|
||||
echo "[1] Adding model via POST /models/add ..."
|
||||
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"model_id\":\"$MODEL_ID\"}")
|
||||
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
|
||||
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
|
||||
echo " HTTP $HTTP_CODE"
|
||||
|
||||
if [ "$HTTP_CODE" -ge 400 ]; then
|
||||
echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
|
||||
echo " Response: $BODY"
|
||||
echo ""
|
||||
echo "RESULT: Model was rejected at add time. Attack blocked."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Step 2: Verify the saved TOML has trust_remote_code = false
|
||||
echo ""
|
||||
echo "[2] Checking saved model card TOML ..."
|
||||
if [ ! -f "$CARD_FILE" ]; then
|
||||
echo " FAIL: Card file not found at $CARD_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
|
||||
echo " SAFE: trust_remote_code = false (fix is active)"
|
||||
else
|
||||
echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
|
||||
fi
|
||||
echo " Contents:"
|
||||
cat "$CARD_FILE"
|
||||
|
||||
# Step 3: Place the instance
|
||||
echo ""
|
||||
echo "[3] Attempting POST /place_instance ..."
|
||||
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"model_id\":\"$MODEL_ID\"}")
|
||||
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
|
||||
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
|
||||
echo " HTTP $PLACE_CODE"
|
||||
echo " Response: $PLACE_BODY"
|
||||
|
||||
# Step 3b: Send a chat completion to actually trigger tokenizer loading
|
||||
echo ""
|
||||
echo "[3b] Sending chat completion to trigger tokenizer load ..."
|
||||
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
|
||||
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
|
||||
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
|
||||
echo " HTTP $CHAT_CODE"
|
||||
echo " Response: $CHAT_BODY"
|
||||
echo ""
|
||||
echo "[3c] Checking for RCE proof ..."
|
||||
sleep 5
|
||||
if [ -f /tmp/exo-rce-proof.txt ]; then
|
||||
echo " VULNERABLE: Remote code executed!"
|
||||
echo " Contents:"
|
||||
cat /tmp/exo-rce-proof.txt
|
||||
else
|
||||
echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
|
||||
fi
|
||||
|
||||
# Step 4: Clean up — delete instance and custom model
|
||||
echo ""
|
||||
echo "[4] Cleaning up ..."
|
||||
|
||||
# Find and delete any instance for this model
|
||||
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
|
||||
import sys, json
|
||||
state = json.load(sys.stdin)
|
||||
for iid, wrapper in state.get('instances', {}).items():
|
||||
for tag, inst in wrapper.items():
|
||||
sa = inst.get('shardAssignments', {})
|
||||
if sa.get('modelId', '') == '$MODEL_ID':
|
||||
print(iid)
|
||||
sys.exit(0)
|
||||
" 2>/dev/null || true)
|
||||
|
||||
if [ -n "$INSTANCE_ID" ]; then
|
||||
echo " Deleting instance $INSTANCE_ID ..."
|
||||
curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
|
||||
echo " Done"
|
||||
else
|
||||
echo " No instance found to delete"
|
||||
fi
|
||||
|
||||
echo " Deleting custom model card ..."
|
||||
curl -s -X DELETE \
|
||||
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
|
||||
echo " Done"
|
||||
|
||||
echo ""
|
||||
echo "=== DONE ==="
|
||||
Reference in New Issue
Block a user