Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
525aaab808 fix 2026-02-26 13:42:07 +00:00
5 changed files with 41 additions and 144 deletions

View File

@@ -3158,23 +3158,6 @@ class AppStore {
return (await response.json()) as TraceStatsResponse;
}
/**
 * Delete the traces that belong to the given task IDs.
 *
 * POSTs `{ taskIds }` as JSON to `/v1/traces/delete` and resolves with the
 * parsed JSON body of the response.
 *
 * @param taskIds - Task IDs whose traces should be deleted.
 * @returns The parsed response body (`deleted` / `notFound` ID lists).
 * @throws Error if the response status is not OK; the message carries the status code.
 */
async deleteTraces(
  taskIds: string[],
): Promise<{ deleted: string[]; notFound: string[] }> {
  const payload = JSON.stringify({ taskIds });
  const res = await fetch("/v1/traces/delete", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: payload,
  });
  if (res.ok) {
    return await res.json();
  }
  throw new Error(`Failed to delete traces: ${res.status}`);
}
/**
* Get the URL for the raw trace file (for Perfetto)
*/
@@ -3318,5 +3301,3 @@ export const fetchTraceStats = (taskId: string) =>
appStore.fetchTraceStats(taskId);
/** Module-level shortcut that forwards to `appStore.getTraceRawUrl`. */
export const getTraceRawUrl = (taskId: string) => {
  return appStore.getTraceRawUrl(taskId);
};
/** Module-level shortcut that forwards to `appStore.deleteTraces`. */
export const deleteTraces = (taskIds: string[]) => {
  return appStore.deleteTraces(taskIds);
};

View File

@@ -3,7 +3,6 @@
import {
listTraces,
getTraceRawUrl,
deleteTraces,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
@@ -11,51 +10,6 @@
// Traces rendered by the list below.
let traces = $state<TraceListItem[]>([]);
// Loading flag; starts out true.
let loading = $state(true);
// Message of the last failure, or null when there is none.
let error = $state<string | null>(null);
// Task IDs of the traces currently selected. Replaced with a fresh Set on every change.
let selectedIds = $state<Set<string>>(new Set());
// True while handleDelete is awaiting the delete request and the follow-up refresh.
let deleting = $state(false);
// True when the list is non-empty and the selection is as large as the list
// (sizes are compared, not the individual members).
let allSelected = $derived(
traces.length > 0 && selectedIds.size === traces.length,
);
/**
 * Flip the selection state of one trace.
 *
 * Works on a copy and reassigns `selectedIds`, so the state variable always
 * receives a new Set instance instead of being mutated in place.
 *
 * @param taskId - Task ID of the trace to select or deselect.
 */
function toggleSelect(taskId: string) {
  const updated = new Set(selectedIds);
  // Set.delete returns false when the ID was absent, which is the "select" case.
  if (!updated.delete(taskId)) {
    updated.add(taskId);
  }
  selectedIds = updated;
}
/**
 * Clear the selection when every trace is already selected; otherwise select
 * the task ID of every trace in the list.
 */
function toggleSelectAll() {
  selectedIds = allSelected
    ? new Set<string>()
    : new Set(traces.map((t) => t.taskId));
}
/**
 * Delete the selected traces after the user confirms.
 *
 * Does nothing when the selection is empty or the confirmation is declined.
 * On success the selection is cleared and the list is refreshed; on failure
 * the error message is stored in `error`. `deleting` is true for the whole
 * duration of the request and the refresh, and is always reset afterwards.
 */
async function handleDelete() {
  const count = selectedIds.size;
  if (count === 0) return;

  const confirmed = confirm(
    `Delete ${count} trace${count === 1 ? "" : "s"}? This cannot be undone.`,
  );
  if (!confirmed) return;

  deleting = true;
  try {
    const ids = [...selectedIds];
    await deleteTraces(ids);
    selectedIds = new Set();
    await refresh();
  } catch (e) {
    // Anything that is not an Error instance gets the generic message.
    error = e instanceof Error ? e.message : "Failed to delete traces";
  } finally {
    deleting = false;
  }
}
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
@@ -155,16 +109,6 @@
</h1>
</div>
<div class="flex items-center gap-3">
{#if selectedIds.size > 0}
<button
type="button"
class="text-xs font-mono text-red-400 hover:text-red-300 transition-colors uppercase border border-red-500/40 px-2 py-1 rounded"
onclick={handleDelete}
disabled={deleting}
>
{deleting ? "Deleting..." : `Delete (${selectedIds.size})`}
</button>
{/if}
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
@@ -199,41 +143,14 @@
</div>
{:else}
<div class="space-y-3">
<div class="flex items-center gap-2 px-1">
<button
type="button"
class="text-xs font-mono uppercase transition-colors {allSelected
? 'text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
onclick={toggleSelectAll}
>
{allSelected ? "Deselect all" : "Select all"}
</button>
</div>
{#each traces as trace}
{@const isSelected = selectedIds.has(trace.taskId)}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
role="button"
tabindex="0"
class="w-full text-left rounded border-l-2 border-r border-t border-b transition-all p-4 flex items-center justify-between gap-4 cursor-pointer {isSelected
? 'bg-exo-yellow/10 border-l-exo-yellow border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30'
: 'bg-exo-black/30 border-l-transparent border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30 hover:bg-white/[0.03]'}"
onclick={() => toggleSelect(trace.taskId)}
onkeydown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
toggleSelect(trace.taskId);
}
}}
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono transition-colors truncate block {isSelected
? 'text-exo-yellow'
: 'text-white hover:text-exo-yellow'}"
onclick={(e) => e.stopPropagation()}
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
@@ -243,11 +160,7 @@
)}
</div>
</div>
<!-- svelte-ignore a11y_click_events_have_key_events -->
<div
class="flex items-center gap-2 shrink-0"
onclick={(e) => e.stopPropagation()}
>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"

View File

@@ -81,8 +81,6 @@ from exo.shared.types.api import (
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteTracesRequest,
DeleteTracesResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -346,7 +344,6 @@ class API:
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.get("/v1/traces")(self.list_traces)
self.app.post("/v1/traces/delete")(self.delete_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
@@ -527,15 +524,15 @@ class API:
if (
model_card.model_id,
sharding,
instance_meta,
instance.sharding(),
instance.instance_meta(),
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
sharding=instance.sharding(),
instance_meta=instance.instance_meta(),
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -544,8 +541,8 @@ class API:
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
instance.sharding(),
instance.instance_meta(),
len(placement_node_ids),
)
)
@@ -1722,10 +1719,7 @@ class API:
return DeleteDownloadResponse(command_id=command.command_id)
def _get_trace_path(self, task_id: str) -> Path:
    """Return the path of the trace file for ``task_id``.

    The file is ``trace_<task_id>.json`` directly under
    ``EXO_TRACING_CACHE_DIR``. The resolved path must stay inside that
    directory, so a task ID containing path separators or ``..`` segments
    cannot be used to address files elsewhere on disk.

    Args:
        task_id: Task identifier as received from the caller.

    Returns:
        The (unresolved) path of the trace file. The file may not exist.

    Raises:
        HTTPException: 400 when the resolved path falls outside the cache
            directory.
    """
    cache_root = EXO_TRACING_CACHE_DIR.resolve()
    trace_path = EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
    # Compare resolved paths so ".." segments and symlinks are taken into account.
    if not trace_path.resolve().is_relative_to(cache_root):
        raise HTTPException(status_code=400, detail=f"Invalid task ID: {task_id}")
    return trace_path
async def list_traces(self) -> TraceListResponse:
traces: list[TraceListItem] = []
@@ -1824,18 +1818,6 @@ class API:
filename=f"trace_{task_id}.json",
)
async def delete_traces(self, request: DeleteTracesRequest) -> DeleteTracesResponse:
    """Delete the trace files for the requested task IDs.

    Each ID is mapped to its file with ``_get_trace_path`` (whatever that
    raises propagates to the caller) and the file is removed. IDs are
    processed in request order.

    Args:
        request: Carries ``task_ids``, the IDs whose trace files to delete.

    Returns:
        A response listing the IDs whose file was removed (``deleted``) and
        the IDs for which no file was present (``not_found``).
    """
    deleted: list[str] = []
    not_found: list[str] = []
    for task_id in request.task_ids:
        trace_path = self._get_trace_path(task_id)
        # Let unlink() report a missing file instead of checking exists() first:
        # the file can disappear between the check and the removal, which would
        # otherwise surface as an unhandled error.
        try:
            trace_path.unlink()
        except (FileNotFoundError, NotADirectoryError):
            not_found.append(task_id)
        else:
            deleted.append(task_id)
    return DeleteTracesResponse(deleted=deleted, not_found=not_found)
async def get_onboarding(self) -> JSONResponse:
    """Report, under the ``completed`` key, whether ``ONBOARDING_COMPLETE_FILE`` exists."""
    completed = ONBOARDING_COMPLETE_FILE.exists()
    return JSONResponse({"completed": completed})

View File

@@ -437,12 +437,3 @@ class TraceListItem(CamelCaseModel):
# Response model returned by the trace listing endpoint.
class TraceListResponse(CamelCaseModel):
# One TraceListItem per listed trace.
traces: list[TraceListItem]
# Request body accepted by the trace deletion endpoint.
class DeleteTracesRequest(CamelCaseModel):
# Task IDs whose trace files should be deleted.
task_ids: list[str]
# Response body of the trace deletion endpoint.
class DeleteTracesResponse(CamelCaseModel):
# Task IDs whose trace file existed and was removed.
deleted: list[str]
# Task IDs for which no trace file was present.
not_found: list[str]

View File

@@ -4,7 +4,13 @@ from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
Sharding,
ShardMetadata,
TensorShardMetadata,
)
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
    """Look up the shard assigned to ``runner_id``; ``None`` if it has none."""
    assignments = self.shard_assignments.runner_to_shard
    return assignments.get(runner_id)
@staticmethod
def instance_meta() -> InstanceMeta:
    """Return the ``InstanceMeta`` tag of the concrete instance type.

    The base class has no tag of its own; concrete instance classes
    override this method and return their constant.

    Raises:
        NotImplementedError: when called on a class that does not override
            it, rather than silently returning ``None`` in violation of the
            declared return type.
    """
    raise NotImplementedError("instance_meta() must be overridden by subclasses")
def sharding(self) -> Sharding:
    """Derive the sharding mode from the types of the assigned shards.

    Returns ``Sharding.Pipeline`` when every shard is a
    ``PipelineShardMetadata`` and ``Sharding.Tensor`` when every shard is a
    ``TensorShardMetadata``. The pipeline check runs first, so an empty
    assignment map (for which ``all`` is vacuously true) yields
    ``Sharding.Pipeline``.

    Raises:
        ValueError: when the shards are neither all pipeline nor all tensor.
    """
    shards = list(self.shard_assignments.runner_to_shard.values())
    candidates = (
        (PipelineShardMetadata, Sharding.Pipeline),
        (TensorShardMetadata, Sharding.Tensor),
    )
    for shard_type, mode in candidates:
        if all(isinstance(shard, shard_type) for shard in shards):
            return mode
    raise ValueError("shard metadata malformed")
# Instance variant tagged InstanceMeta.MlxRing.
class MlxRingInstance(BaseInstance):
# Hosts keyed by node ID. NOTE(review): presumably the addresses each node is reachable on — confirm.
hosts_by_node: dict[NodeId, list[Host]]
# NOTE(review): what binds or uses this port is not visible here — confirm.
ephemeral_port: int
@staticmethod
def instance_meta() -> InstanceMeta:
# Constant tag for this variant.
return InstanceMeta.MlxRing
# Instance variant tagged InstanceMeta.MlxJaccl.
class MlxJacclInstance(BaseInstance):
# Nested list of optional device names. NOTE(review): the meaning of the two levels and of None is not visible here — confirm.
jaccl_devices: list[list[str | None]]
# One string per node ID. NOTE(review): presumably a coordinator address — confirm.
jaccl_coordinators: dict[NodeId, str]
@staticmethod
def instance_meta() -> InstanceMeta:
# Constant tag for this variant.
return InstanceMeta.MlxJaccl
# TODO: Single node instance
# Union of the concrete instance variants defined above.
Instance = MlxRingInstance | MlxJacclInstance