mirror of https://github.com/exo-explore/exo.git
synced 2026-01-23 21:41:21 -05:00

Compare commits

1 commit: runner-can ...

Author: alexcheema | SHA1: e7f61c3494
@@ -69,6 +69,8 @@ export interface Instance {
     runnerToShard?: Record<string, unknown>;
     nodeToRunner?: Record<string, string>;
   };
+  draftModel?: string;
+  numDraftTokens?: number;
 }

 // Granular node state types from the new state structure
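The new optional fields reach the UI in camelCase while the Python models later in this diff use snake_case; exo's CamelCaseModel helper handles the mapping. A minimal sketch of how that bridge typically works with pydantic v2 (the model id is purely illustrative, and CamelCaseModel is assumed to behave like this):

# Minimal sketch of the snake_case -> camelCase bridge the frontend relies on.
# exo's CamelCaseModel (exo.utils.pydantic_ext) is assumed to behave like this.
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

class InstanceSketch(BaseModel):
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

    draft_model: str | None = None
    num_draft_tokens: int = 4

inst = InstanceSketch(draft_model="mlx-community/Qwen2.5-0.5B-Instruct-4bit")
# Serializing by alias yields the camelCase keys the Instance interface declares.
assert inst.model_dump(by_alias=True) == {
    "draftModel": "mlx-community/Qwen2.5-0.5B-Instruct-4bit",
    "numDraftTokens": 4,
}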
@@ -162,6 +162,12 @@
   let launchingModelId = $state<string | null>(null);
   let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());

+  // Draft model editing state
+  let editingDraftInstanceId = $state<string | null>(null);
+  let editDraftModel = $state<string | null>(null);
+  let editNumDraftTokens = $state<number>(4);
+  let draftEditDropdownSearch = $state("");
+
   // Custom dropdown state
   let isModelDropdownOpen = $state(false);
   let modelDropdownSearch = $state("");
@@ -1012,6 +1018,53 @@
     }
   }

+  // Open draft model edit modal for an instance
+  function openDraftModelEdit(
+    instanceId: string,
+    currentDraftModel: string | null,
+    currentNumTokens: number | null,
+  ): void {
+    editingDraftInstanceId = instanceId;
+    editDraftModel = currentDraftModel;
+    editNumDraftTokens = currentNumTokens ?? 4;
+    draftEditDropdownSearch = "";
+  }
+
+  // Close draft model edit modal
+  function closeDraftModelEdit(): void {
+    editingDraftInstanceId = null;
+    editDraftModel = null;
+    editNumDraftTokens = 4;
+    draftEditDropdownSearch = "";
+  }
+
+  // Save draft model settings for an instance
+  async function saveDraftModel(): Promise<void> {
+    if (!editingDraftInstanceId) return;
+
+    try {
+      const response = await fetch(
+        `/instance/${editingDraftInstanceId}/draft_model`,
+        {
+          method: "PUT",
+          headers: { "Content-Type": "application/json" },
+          body: JSON.stringify({
+            draft_model: editDraftModel,
+            num_draft_tokens: editNumDraftTokens,
+          }),
+        },
+      );
+      if (!response.ok) {
+        const errorText = await response.text();
+        console.error("Failed to set draft model:", errorText);
+      }
+    } catch (error) {
+      console.error("Error setting draft model:", error);
+    } finally {
+      closeDraftModelEdit();
+    }
+  }
+
   // Helper to unwrap tagged unions like { MlxRingInstance: {...} }
   function getTagged(obj: unknown): [string | null, unknown] {
     if (!obj || typeof obj !== "object") return [null, null];
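The endpoint contract the UI calls here is registered as a PUT further down in this diff. A hypothetical client-side sketch of the same call from Python (host, port, and model id are assumptions for illustration):

# Hypothetical client sketch of the draft-model endpoint the UI uses.
import httpx

def set_draft_model(instance_id: str, draft_model: str | None, num_draft_tokens: int = 4) -> dict:
    resp = httpx.put(
        f"http://localhost:52415/instance/{instance_id}/draft_model",
        json={"draft_model": draft_model, "num_draft_tokens": num_draft_tokens},
    )
    resp.raise_for_status()
    # Expected shape per SetDraftModelResponse below:
    # {"message": ..., "command_id": ..., "instance_id": ...}
    return resp.json()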
@@ -1037,6 +1090,8 @@
     nodeNames: string[];
     nodeIds: string[];
     nodeCount: number;
+    draftModel: string | null;
+    numDraftTokens: number | null;
   } {
     const [instanceTag, instance] = getTagged(instanceWrapped);
     if (!instance || typeof instance !== "object") {
@@ -1046,6 +1101,8 @@
         nodeNames: [],
         nodeIds: [],
         nodeCount: 0,
+        draftModel: null,
+        numDraftTokens: null,
       };
     }

@@ -1063,6 +1120,8 @@
       nodeToRunner?: Record<string, string>;
       runnerToShard?: Record<string, unknown>;
     };
+    draftModel?: string;
+    numDraftTokens?: number;
   };

   // Sharding strategy from first shard
@@ -1085,12 +1144,18 @@
       return node?.friendly_name || nodeId.slice(0, 8);
     });

+    // Draft model for speculative decoding
+    const draftModel = inst.draftModel ?? null;
+    const numDraftTokens = inst.numDraftTokens ?? null;
+
     return {
       instanceType,
       sharding,
       nodeNames,
       nodeIds,
       nodeCount: nodeIds.length,
+      draftModel,
+      numDraftTokens,
     };
   }
@@ -1788,12 +1853,42 @@
           >{id.slice(0, 8).toUpperCase()}</span
         >
       </div>
-      <button
-        onclick={() => deleteInstance(id)}
-        class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
-      >
-        DELETE
-      </button>
+      <div class="flex items-center gap-2">
+        <button
+          onclick={() =>
+            openDraftModelEdit(
+              id,
+              instanceInfo.draftModel,
+              instanceInfo.numDraftTokens,
+            )}
+          class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel
+            ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500'
+            : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
+          title={instanceInfo.draftModel
+            ? `Draft: ${instanceInfo.draftModel.split("/").pop()} (${instanceInfo.numDraftTokens}t)`
+            : "Configure speculative decoding"}
+        >
+          <svg
+            class="w-3.5 h-3.5"
+            fill="none"
+            viewBox="0 0 24 24"
+            stroke="currentColor"
+            stroke-width="2"
+          >
+            <path
+              stroke-linecap="round"
+              stroke-linejoin="round"
+              d="M13 10V3L4 14h7v7l9-11h-7z"
+            />
+          </svg>
+        </button>
+        <button
+          onclick={() => deleteInstance(id)}
+          class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
+        >
+          DELETE
+        </button>
+      </div>
     </div>
     <div class="pl-2">
       <div
@@ -1806,6 +1901,13 @@
           >{instanceInfo.sharding} ({instanceInfo.instanceType})</span
         >
       </div>
+      {#if instanceInfo.draftModel}
+        <div class="text-cyan-400/80 text-xs font-mono">
+          Draft: <span class="text-cyan-400"
+            >{instanceInfo.draftModel.split("/").pop()} ({instanceInfo.numDraftTokens}t)</span
+          >
+        </div>
+      {/if}
      {#if instanceModelId && instanceModelId !== "Unknown" && instanceModelId !== "Unknown Model"}
        <a
          class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -2710,12 +2812,42 @@
           >{id.slice(0, 8).toUpperCase()}</span
         >
       </div>
-      <button
-        onclick={() => deleteInstance(id)}
-        class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
-      >
-        DELETE
-      </button>
+      <div class="flex items-center gap-2">
+        <button
+          onclick={() =>
+            openDraftModelEdit(
+              id,
+              instanceInfo.draftModel,
+              instanceInfo.numDraftTokens,
+            )}
+          class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel
+            ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500'
+            : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
+          title={instanceInfo.draftModel
+            ? `Draft: ${instanceInfo.draftModel.split("/").pop()} (${instanceInfo.numDraftTokens}t)`
+            : "Configure speculative decoding"}
+        >
+          <svg
+            class="w-3.5 h-3.5"
+            fill="none"
+            viewBox="0 0 24 24"
+            stroke="currentColor"
+            stroke-width="2"
+          >
+            <path
+              stroke-linecap="round"
+              stroke-linejoin="round"
+              d="M13 10V3L4 14h7v7l9-11h-7z"
+            />
+          </svg>
+        </button>
+        <button
+          onclick={() => deleteInstance(id)}
+          class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
+        >
+          DELETE
+        </button>
+      </div>
     </div>
     <div class="pl-2">
       <div
@@ -2728,6 +2860,13 @@
           >{instanceInfo.sharding} ({instanceInfo.instanceType})</span
         >
       </div>
+      {#if instanceInfo.draftModel}
+        <div class="text-cyan-400/80 text-xs font-mono">
+          Draft: <span class="text-cyan-400"
+            >{instanceInfo.draftModel.split("/").pop()} ({instanceInfo.numDraftTokens}t)</span
+          >
+        </div>
+      {/if}
      {#if instanceModelId && instanceModelId !== "Unknown" && instanceModelId !== "Unknown Model"}
        <a
          class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -3005,5 +3144,111 @@
   {/if}
 </div>
 {/if}
+
+<!-- Draft Model Edit Modal -->
+{#if editingDraftInstanceId}
+  <div
+    class="fixed inset-0 bg-black/80 backdrop-blur-sm z-50 flex items-center justify-center"
+    role="dialog"
+    aria-modal="true"
+    onclick={closeDraftModelEdit}
+    onkeydown={(e) => e.key === "Escape" && closeDraftModelEdit()}
+  >
+    <div
+      class="bg-exo-dark-gray border border-exo-yellow/30 p-6 max-w-md w-full mx-4"
+      onclick={(e) => e.stopPropagation()}
+      onkeydown={(e) => e.stopPropagation()}
+      role="document"
+    >
+      <h3 class="text-exo-yellow font-mono text-lg tracking-wider mb-4">
+        SPECULATIVE DECODING
+      </h3>
+
+      <div class="space-y-4">
+        <!-- Draft Model Selection -->
+        <div>
+          <label
+            class="block text-white/60 text-xs font-mono tracking-wider mb-2"
+            >DRAFT MODEL</label
+          >
+          <div class="relative">
+            <input
+              type="text"
+              bind:value={draftEditDropdownSearch}
+              placeholder={editDraftModel || "Select a draft model..."}
+              class="w-full bg-exo-dark-gray/60 border border-exo-yellow/30 text-white text-sm font-mono px-3 py-2 focus:outline-none focus:border-exo-yellow/60"
+            />
+            {#if draftEditDropdownSearch}
+              <div
+                class="absolute z-10 w-full mt-1 bg-exo-dark-gray border border-exo-yellow/30 max-h-48 overflow-y-auto"
+              >
+                {#each models.filter((m) => m.id
+                    .toLowerCase()
+                    .includes(draftEditDropdownSearch.toLowerCase())) as model}
+                  <button
+                    class="w-full text-left px-3 py-2 text-sm font-mono text-white/80 hover:bg-exo-yellow/10 hover:text-exo-yellow cursor-pointer"
+                    onclick={() => {
+                      editDraftModel = model.hugging_face_id || model.id;
+                      draftEditDropdownSearch = "";
+                    }}
+                  >
+                    {model.name || model.id}
+                  </button>
+                {/each}
+              </div>
+            {/if}
+          </div>
+          {#if editDraftModel}
+            <div
+              class="mt-2 flex items-center justify-between text-cyan-400 text-xs font-mono"
+            >
+              <span>{editDraftModel}</span>
+              <button
+                class="text-red-400 hover:text-red-300 cursor-pointer"
+                onclick={() => (editDraftModel = null)}
+              >
+                Clear
+              </button>
+            </div>
+          {/if}
+        </div>

+        <!-- Num Draft Tokens -->
+        <div>
+          <label
+            class="block text-white/60 text-xs font-mono tracking-wider mb-2"
+            >DRAFT TOKENS</label
+          >
+          <input
+            type="number"
+            min="1"
+            max="16"
+            bind:value={editNumDraftTokens}
+            class="w-full bg-exo-dark-gray/60 border border-exo-yellow/30 text-white text-sm font-mono px-3 py-2 focus:outline-none focus:border-exo-yellow/60"
+          />
+          <p class="text-white/40 text-xs font-mono mt-1">
+            Number of tokens to draft per step (1-16)
+          </p>
+        </div>
+      </div>

+      <!-- Actions -->
+      <div class="flex justify-end gap-3 mt-6">
+        <button
+          onclick={closeDraftModelEdit}
+          class="px-4 py-2 text-xs font-mono tracking-wider uppercase border border-white/30 text-white/60 hover:bg-white/10 hover:text-white transition-all cursor-pointer"
+        >
+          CANCEL
+        </button>
+        <button
+          onclick={saveDraftModel}
+          class="px-4 py-2 text-xs font-mono tracking-wider uppercase border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 transition-all cursor-pointer"
+        >
+          SAVE
+        </button>
+      </div>
+    </div>
+  </div>
+{/if}
 </main>
 </div>
@@ -55,6 +55,8 @@ from exo.shared.types.api import (
     PlaceInstanceParams,
     PlacementPreview,
     PlacementPreviewResponse,
+    SetDraftModelParams,
+    SetDraftModelResponse,
     StreamingChoiceResponse,
 )
 from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
@@ -68,6 +70,7 @@ from exo.shared.types.commands import (
     ImageGeneration,
     PlaceInstance,
     SendInputChunk,
+    SetInstanceDraftModel,
     TaskFinished,
 )
 from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -215,6 +218,7 @@ class API:
         self.app.get("/instance/previews")(self.get_placement_previews)
         self.app.get("/instance/{instance_id}")(self.get_instance)
         self.app.delete("/instance/{instance_id}")(self.delete_instance)
+        self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
         self.app.get("/models")(self.get_models)
         self.app.get("/v1/models")(self.get_models)
         self.app.post("/v1/chat/completions", response_model=None)(
@@ -238,6 +242,8 @@ class API:
             sharding=payload.sharding,
             instance_meta=payload.instance_meta,
             min_nodes=payload.min_nodes,
+            draft_model=payload.draft_model,
+            num_draft_tokens=payload.num_draft_tokens,
         )
         await self._send(command)

@@ -437,6 +443,24 @@ class API:
             instance_id=instance_id,
         )

+    async def set_draft_model(
+        self, instance_id: InstanceId, payload: SetDraftModelParams
+    ) -> SetDraftModelResponse:
+        if instance_id not in self.state.instances:
+            raise HTTPException(status_code=404, detail="Instance not found")
+
+        command = SetInstanceDraftModel(
+            instance_id=instance_id,
+            draft_model=payload.draft_model,
+            num_draft_tokens=payload.num_draft_tokens,
+        )
+        await self._send(command)
+        return SetDraftModelResponse(
+            message="Command received.",
+            command_id=command.command_id,
+            instance_id=instance_id,
+        )
+
     async def _chat_chunk_stream(
         self, command_id: CommandId
     ) -> AsyncGenerator[TokenChunk, None]:

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
     PlaceInstance,
     RequestEventLog,
     SendInputChunk,
+    SetInstanceDraftModel,
     TaskFinished,
     TestCommand,
 )
@@ -31,6 +32,7 @@ from exo.shared.types.events import (
     IndexedEvent,
     InputChunkReceived,
     InstanceDeleted,
+    InstanceDraftModelUpdated,
     NodeGatheredInfo,
     NodeTimedOut,
     TaskCreated,
@@ -278,6 +280,14 @@ class Master:
                             chunk=chunk,
                         )
                     )
+                case SetInstanceDraftModel():
+                    generated_events.append(
+                        InstanceDraftModelUpdated(
+                            instance_id=command.instance_id,
+                            draft_model=command.draft_model,
+                            num_draft_tokens=command.num_draft_tokens,
+                        )
+                    )
                 case TaskFinished():
                     generated_events.append(
                         TaskDeleted(
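Note the command handler never mutates state: SetInstanceDraftModel is translated into an InstanceDraftModelUpdated event, and only event application (a later hunk) touches state. The shape of that loop, reduced to a sketch with stand-in types:

# Command -> event -> state, reduced to a sketch with stand-in types.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class State:
    draft_model: str | None = None

@dataclass(frozen=True)
class InstanceDraftModelUpdated:
    draft_model: str | None

def handle_command(draft_model: str | None) -> InstanceDraftModelUpdated:
    return InstanceDraftModelUpdated(draft_model=draft_model)  # no state mutation here

def apply(event: InstanceDraftModelUpdated, state: State) -> State:
    return replace(state, draft_model=event.draft_model)  # pure transition

state = State()
state = apply(handle_command("tiny-draft"), state)
assert state.draft_model == "tiny-draft"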
@@ -138,6 +138,8 @@ def place_instance(
                 shard_assignments=shard_assignments,
                 jaccl_devices=mlx_jaccl_devices,
                 jaccl_coordinators=mlx_jaccl_coordinators,
+                draft_model=command.draft_model,
+                num_draft_tokens=command.num_draft_tokens,
             )
         case InstanceMeta.MlxRing:
             ephemeral_port = random_ephemeral_port()
@@ -152,6 +154,8 @@ def place_instance(
                 shard_assignments=shard_assignments,
                 hosts_by_node=hosts_by_node,
                 ephemeral_port=ephemeral_port,
+                draft_model=command.draft_model,
+                num_draft_tokens=command.num_draft_tokens,
             )

     return target_instances
@@ -12,6 +12,7 @@ from exo.shared.types.events import (
     InputChunkReceived,
     InstanceCreated,
     InstanceDeleted,
+    InstanceDraftModelUpdated,
     NodeDownloadProgress,
     NodeGatheredInfo,
     NodeTimedOut,
@@ -60,6 +61,8 @@ def event_apply(event: Event, state: State) -> State:
             return apply_instance_created(event, state)
         case InstanceDeleted():
             return apply_instance_deleted(event, state)
+        case InstanceDraftModelUpdated():
+            return apply_instance_draft_model_updated(event, state)
         case NodeTimedOut():
             return apply_node_timed_out(event, state)
         case NodeDownloadProgress():
@@ -178,6 +181,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
     return state.model_copy(update={"instances": new_instances})


+def apply_instance_draft_model_updated(
+    event: InstanceDraftModelUpdated, state: State
+) -> State:
+    if event.instance_id not in state.instances:
+        return state
+    instance = state.instances[event.instance_id]
+    updated_instance = instance.model_copy(
+        update={
+            "draft_model": event.draft_model,
+            "num_draft_tokens": event.num_draft_tokens,
+        }
+    )
+    new_instances: Mapping[InstanceId, Instance] = {
+        **state.instances,
+        event.instance_id: updated_instance,
+    }
+    return state.model_copy(update={"instances": new_instances})
+
+
 def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
     new_runners: Mapping[RunnerId, RunnerStatus] = {
         **state.runners,
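A property of this reducer worth keeping in mind: it is pure, and an unknown instance id leaves the previous state untouched. A minimal sketch of that behavior with stand-in pydantic models (names here are illustrative, not exo's real types):

# Illustrative stand-in models; exo's real State/Instance are richer.
from pydantic import BaseModel

class FakeInstance(BaseModel):
    draft_model: str | None = None
    num_draft_tokens: int = 4

class FakeState(BaseModel):
    instances: dict[str, FakeInstance] = {}

def apply_update(state: FakeState, instance_id: str, draft_model: str | None, n: int) -> FakeState:
    if instance_id not in state.instances:
        return state  # unknown instance: reducer is a no-op
    updated = state.instances[instance_id].model_copy(
        update={"draft_model": draft_model, "num_draft_tokens": n}
    )
    return state.model_copy(update={"instances": {**state.instances, instance_id: updated}})

s0 = FakeState(instances={"abc": FakeInstance()})
s1 = apply_update(s0, "abc", "tiny-draft", 6)
assert s0.instances["abc"].draft_model is None  # old state untouched
assert s1.instances["abc"].draft_model == "tiny-draft"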
@@ -177,6 +177,8 @@ class ChatCompletionTaskParams(BaseModel):
     tool_choice: str | dict[str, Any] | None = None
     parallel_tool_calls: bool | None = None
     user: str | None = None
+    # Speculative decoding: tokens to draft per iteration (if instance has draft model)
+    num_draft_tokens: int = 3


 class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -188,6 +190,8 @@ class PlaceInstanceParams(BaseModel):
     sharding: Sharding = Sharding.Pipeline
     instance_meta: InstanceMeta = InstanceMeta.MlxRing
     min_nodes: int = 1
+    draft_model: ModelId | None = None  # For speculative decoding
+    num_draft_tokens: int = 4  # Tokens to draft per iteration

     @field_validator("sharding", "instance_meta", mode="plain")
     @classmethod
@@ -340,3 +344,14 @@ class ImageListItem(BaseModel, frozen=True):

 class ImageListResponse(BaseModel, frozen=True):
     data: list[ImageListItem]
+
+
+class SetDraftModelParams(BaseModel):
+    draft_model: ModelId | None = None  # None to disable speculative decoding
+    num_draft_tokens: int = 4
+
+
+class SetDraftModelResponse(BaseModel):
+    message: str
+    command_id: CommandId
+    instance_id: InstanceId
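The request body is deliberately small: sending `draft_model: null` (or omitting it) disables speculative decoding. A minimal sketch of how such a params model validates, with ModelId approximated as a plain string:

# ModelId is approximated as str here; exo's real type is a validated id.
from pydantic import BaseModel

class SetDraftModelParamsSketch(BaseModel):
    draft_model: str | None = None  # None disables speculative decoding
    num_draft_tokens: int = 4

enable = SetDraftModelParamsSketch.model_validate(
    {"draft_model": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", "num_draft_tokens": 6}
)
disable = SetDraftModelParamsSketch.model_validate({"draft_model": None})
assert enable.num_draft_tokens == 6
assert disable.draft_model is None and disable.num_draft_tokens == 4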
@@ -7,7 +7,7 @@ from exo.shared.types.api import (
     ImageGenerationTaskParams,
 )
 from exo.shared.types.chunks import InputImageChunk
-from exo.shared.types.common import CommandId, NodeId
+from exo.shared.types.common import CommandId, ModelId, NodeId
 from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
 from exo.shared.types.worker.shards import Sharding
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -38,6 +38,8 @@ class PlaceInstance(BaseCommand):
     sharding: Sharding
     instance_meta: InstanceMeta
     min_nodes: int
+    draft_model: ModelId | None = None  # For speculative decoding
+    num_draft_tokens: int = 4  # Tokens to draft per iteration


 class CreateInstance(BaseCommand):
@@ -48,6 +50,14 @@ class DeleteInstance(BaseCommand):
     instance_id: InstanceId


+class SetInstanceDraftModel(BaseCommand):
+    """Set or update the draft model for an existing instance."""
+
+    instance_id: InstanceId
+    draft_model: ModelId | None  # None to disable speculative decoding
+    num_draft_tokens: int = 4
+
+
 class TaskFinished(BaseCommand):
     finished_command_id: CommandId

@@ -71,6 +81,7 @@ Command = (
     | PlaceInstance
     | CreateInstance
     | DeleteInstance
+    | SetInstanceDraftModel
     | TaskFinished
     | SendInputChunk
 )
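Commands form a tagged union that the master dispatches on with structural pattern matching, as seen in the Master hunk above. The same dispatch idiom in isolation, with stand-in dataclasses replacing exo's pydantic command models:

# Minimal sketch of dispatch over a command union; stand-in dataclasses only.
from dataclasses import dataclass

@dataclass
class SetInstanceDraftModel:
    instance_id: str
    draft_model: str | None
    num_draft_tokens: int = 4

@dataclass
class DeleteInstance:
    instance_id: str

Command = SetInstanceDraftModel | DeleteInstance

def handle(cmd: Command) -> str:
    match cmd:
        case SetInstanceDraftModel(draft_model=None):
            return "disable speculative decoding"
        case SetInstanceDraftModel(draft_model=m, num_draft_tokens=n):
            return f"draft with {m}, {n} tokens per step"
        case DeleteInstance(instance_id=iid):
            return f"delete {iid}"

assert handle(SetInstanceDraftModel("abc", "tiny-draft", 6)).startswith("draft")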
@@ -4,7 +4,7 @@ from pydantic import Field

 from exo.shared.topology import Connection
 from exo.shared.types.chunks import GenerationChunk, InputImageChunk
-from exo.shared.types.common import CommandId, Id, NodeId, SessionId
+from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
 from exo.shared.types.tasks import Task, TaskId, TaskStatus
 from exo.shared.types.worker.downloads import DownloadProgress
 from exo.shared.types.worker.instances import Instance, InstanceId
@@ -67,6 +67,14 @@ class InstanceDeleted(BaseEvent):
     instance_id: InstanceId


+class InstanceDraftModelUpdated(BaseEvent):
+    """Draft model updated on an existing instance."""
+
+    instance_id: InstanceId
+    draft_model: ModelId | None
+    num_draft_tokens: int
+
+
 class RunnerStatusUpdated(BaseEvent):
     runner_id: RunnerId
     runner_status: RunnerStatus
@@ -118,6 +126,7 @@ Event = (
     | TaskAcknowledged
     | InstanceCreated
     | InstanceDeleted
+    | InstanceDraftModelUpdated
     | RunnerStatusUpdated
     | RunnerDeleted
     | NodeTimedOut
@@ -40,6 +40,12 @@ class DownloadModel(BaseTask):  # emitted by Worker
     shard_metadata: ShardMetadata


+class DownloadDraftModel(BaseTask):  # emitted by Worker
+    """Download a draft model for speculative decoding (rank 0 only)."""
+
+    model_id: str  # HuggingFace model ID
+
+
 class LoadModel(BaseTask):  # emitted by Worker
     pass

@@ -80,9 +86,17 @@ class Shutdown(BaseTask):  # emitted by Worker
     runner_id: RunnerId


+class SetDraftModel(BaseTask):  # emitted by Worker
+    """Load or clear a draft model on an already-running instance."""
+
+    model_id: str | None  # HuggingFace model ID, or None to clear
+    num_draft_tokens: int = 4
+
+
 Task = (
     CreateRunner
     | DownloadModel
+    | DownloadDraftModel
     | ConnectToGroup
     | LoadModel
     | StartWarmup
@@ -90,4 +104,5 @@ Task = (
     | ImageGeneration
     | ImageEdits
     | Shutdown
+    | SetDraftModel
 )
@@ -2,7 +2,7 @@ from enum import Enum

 from pydantic import model_validator

-from exo.shared.types.common import Host, Id, NodeId
+from exo.shared.types.common import Host, Id, ModelId, NodeId
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
 class BaseInstance(TaggedModel):
     instance_id: InstanceId
     shard_assignments: ShardAssignments
+    draft_model: ModelId | None = None  # For speculative decoding (rank 0 only)
+    num_draft_tokens: int = 4  # Tokens to draft per iteration (when draft_model is set)

     def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
         return self.shard_assignments.runner_to_shard.get(runner_id, None)
@@ -48,6 +48,8 @@ def maybe_quantize_kv_cache(
 def warmup_inference(
     model: Model,
     tokenizer: TokenizerWrapper,
+    draft_model: Model | None = None,
+    num_draft_tokens: int = 4,
 ) -> int:
     content = "Prompt to warm up the inference engine. Repeat this."

@@ -66,25 +68,30 @@

     tokens_generated = 0

-    cache = make_kv_cache(
-        model=model,
-    )
-
     # Use a default sampler for warmup
     sampler = make_sampler(temp=0.7)

+    generate_kwargs: dict[str, object] = {
+        "model": model,
+        "tokenizer": tokenizer,
+        "prompt": warmup_prompt,
+        "max_tokens": 50,
+        "sampler": sampler,
+        "prefill_step_size": 2048,
+        "kv_group_size": KV_GROUP_SIZE,
+        "kv_bits": KV_BITS,
+    }
+
+    # Warm up with draft model if provided (speculative decoding path)
+    if draft_model is not None:
+        logger.info("Warming up with speculative decoding (draft model)")
+        generate_kwargs["draft_model"] = draft_model
+        generate_kwargs["num_draft_tokens"] = num_draft_tokens
+    else:
+        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
+
+    logger.info("Generating warmup tokens")
-    for _r in stream_generate(
-        model=model,
-        tokenizer=tokenizer,
-        prompt=warmup_prompt,
-        max_tokens=50,
-        sampler=sampler,
-        prompt_cache=cache,
-        prefill_step_size=2048,
-        kv_group_size=KV_GROUP_SIZE,
-        kv_bits=KV_BITS,
-    ):
+    for _r in stream_generate(**generate_kwargs):  # type: ignore[arg-type]
         logger.info("Generated warmup token: " + str(_r.text))
         tokens_generated += 1
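As context for why the draft path is worth warming up: the draft model proposes num_draft_tokens tokens per step and the main model verifies them in a single forward pass, so throughput depends on how many proposals get accepted. A back-of-envelope estimate, using the standard geometric acceptance approximation (not something exo computes):

# Back-of-envelope tokens-per-step estimate for speculative decoding.
# Assumes an i.i.d. per-token acceptance probability p (a simplification).
def expected_tokens_per_step(p: float, k: int) -> float:
    # 1 + p + p^2 + ... + p^k: accepted draft tokens plus the verified/corrected one
    return sum(p**i for i in range(k + 1))

for k in (3, 4, 8):
    print(k, round(expected_tokens_per_step(0.8, k), 2))
# k=4 at p=0.8 yields ~3.36 tokens per main-model step; raising k has
# diminishing returns while adding draft-model latency to every step.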
@@ -120,6 +127,8 @@ def mlx_generate(
     tokenizer: TokenizerWrapper,
     task: ChatCompletionTaskParams,
     prompt: str,
+    draft_model: Model | None = None,
+    num_draft_tokens: int = 4,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -131,8 +140,6 @@
     if task.seed is not None:
         mx.random.seed(task.seed)

-    caches = make_kv_cache(model=model)
-
     logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
     if is_bench:
         # Only sample length eos tokens
@@ -145,19 +152,31 @@
     )

     max_tokens = task.max_tokens or MAX_TOKENS
-    for out in stream_generate(
-        model=model,
-        tokenizer=tokenizer,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        sampler=sampler,
-        logits_processors=logits_processors,
-        prompt_cache=caches,
-        # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
-        prefill_step_size=2048,
-        kv_group_size=KV_GROUP_SIZE,
-        kv_bits=KV_BITS,
-    ):
+
+    # Build kwargs for stream_generate, conditionally adding draft model params
+    generate_kwargs: dict[str, object] = {
+        "model": model,
+        "tokenizer": tokenizer,
+        "prompt": prompt,
+        "max_tokens": max_tokens,
+        "sampler": sampler,
+        "logits_processors": logits_processors,
+        "prefill_step_size": 2048,
+        "kv_group_size": KV_GROUP_SIZE,
+        "kv_bits": KV_BITS,
+    }
+
+    # Add speculative decoding parameters if draft model is provided
+    # Note: When using draft_model, we let mlx_lm create its own trimmable cache
+    # as speculative decoding requires cache trimming capabilities
+    if draft_model is not None:
+        generate_kwargs["draft_model"] = draft_model
+        generate_kwargs["num_draft_tokens"] = num_draft_tokens
+    else:
+        # Only use custom cache for non-speculative generation
+        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
+
+    for out in stream_generate(**generate_kwargs):  # type: ignore[arg-type]
         logger.info(out.text)

         stats: GenerationStats | None = None
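For reference, the same draft-model plumbing in plain mlx_lm looks roughly like this; parameter names match what this diff passes to stream_generate, but exact signatures vary by mlx_lm version, so treat it as a sketch:

# Standalone sketch of speculative decoding with mlx_lm, outside exo.
# Model ids are illustrative; check your installed mlx_lm version's API.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")
draft_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

for out in stream_generate(
    model,
    tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    max_tokens=64,
    draft_model=draft_model,  # small model proposes tokens
    num_draft_tokens=4,       # proposals verified per main-model pass
):
    print(out.text, end="", flush=True)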
@@ -230,6 +230,27 @@ def load_mlx_items(
     return cast(Model, model), tokenizer


+def load_draft_model(model_id: ModelId) -> nn.Module:
+    """Load a draft model for speculative decoding (rank 0 only).
+
+    Draft models are small models (typically 0.5B-2B parameters) used to
+    generate candidate tokens quickly, which are then verified by the main
+    model in a single forward pass.
+
+    Assumes the model has already been downloaded by the worker.
+
+    Args:
+        model_id: HuggingFace model ID for the draft model
+
+    Returns:
+        The loaded draft model
+    """
+    model_path = build_model_path(model_id)
+    draft_model, _ = load_model(model_path, strict=True)
+    logger.info(f"Loaded draft model from {model_path}")
+    return draft_model
+
+
 def shard_and_load(
     shard_metadata: ShardMetadata,
     group: Group,
@@ -29,8 +29,10 @@ from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.state import State
 from exo.shared.types.tasks import (
     CreateRunner,
+    DownloadDraftModel,
     DownloadModel,
     ImageEdits,
+    SetDraftModel,
     Shutdown,
     Task,
     TaskStatus,
@@ -51,6 +53,7 @@ from exo.utils.info_gatherer.net_profile import check_reachable
 from exo.worker.download.download_utils import (
     map_repo_download_progress_to_download_progress_data,
 )
+from exo.worker.download.impl_shard_downloader import build_full_shard
 from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
 from exo.worker.plan import plan
 from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -206,42 +209,10 @@ class Worker:
                         )
                     )
                 case DownloadModel(shard_metadata=shard):
-                    if shard.model_card.model_id not in self.download_status:
-                        progress = DownloadPending(
-                            shard_metadata=shard, node_id=self.node_id
-                        )
-                        self.download_status[shard.model_card.model_id] = progress
-                        await self.event_sender.send(
-                            NodeDownloadProgress(download_progress=progress)
-                        )
-                    initial_progress = (
-                        await self.shard_downloader.get_shard_download_status_for_shard(
-                            shard
-                        )
-                    )
-                    if initial_progress.status == "complete":
-                        progress = DownloadCompleted(
-                            shard_metadata=shard,
-                            node_id=self.node_id,
-                            total_bytes=initial_progress.total_bytes,
-                        )
-                        self.download_status[shard.model_card.model_id] = progress
-                        await self.event_sender.send(
-                            NodeDownloadProgress(download_progress=progress)
-                        )
-                        await self.event_sender.send(
-                            TaskStatusUpdated(
-                                task_id=task.task_id,
-                                task_status=TaskStatus.Complete,
-                            )
-                        )
-                    else:
-                        await self.event_sender.send(
-                            TaskStatusUpdated(
-                                task_id=task.task_id, task_status=TaskStatus.Running
-                            )
-                        )
-                        self._handle_shard_download_process(task, initial_progress)
+                    await self._handle_download(shard, task)
+                case DownloadDraftModel(model_id=model_id):
+                    shard = await build_full_shard(ModelId(model_id))
+                    await self._handle_download(shard, task)
                 case Shutdown(runner_id=runner_id):
                     try:
                         with fail_after(3):
@@ -292,6 +263,25 @@ class Worker:
                     await self.runners[self._task_to_runner_id(task)].start_task(
                         modified_task
                     )
+                case SetDraftModel(
+                    model_id=draft_model_id, num_draft_tokens=num_tokens
+                ):
+                    runner = self.runners[self._task_to_runner_id(task)]
+                    await runner.start_task(task)
+                    # Update bound_instance to reflect new/cleared draft model
+                    updated_instance = runner.bound_instance.instance.model_copy(
+                        update={
+                            "draft_model": (
+                                ModelId(draft_model_id)
+                                if draft_model_id is not None
+                                else None
+                            ),
+                            "num_draft_tokens": num_tokens,
+                        }
+                    )
+                    runner.bound_instance = runner.bound_instance.model_copy(
+                        update={"instance": updated_instance}
+                    )
                 case task:
                     await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -386,6 +376,46 @@ class Worker:
         self._tg.start_soon(runner.run)
         return runner

+    async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
+        """Handle model download - shared logic for main and draft models."""
+        model_id = shard.model_card.model_id
+
+        if model_id not in self.download_status:
+            progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
+            self.download_status[model_id] = progress
+            await self.event_sender.send(
+                NodeDownloadProgress(download_progress=progress)
+            )
+
+        initial_progress = (
+            await self.shard_downloader.get_shard_download_status_for_shard(shard)
+        )
+
+        if initial_progress.status == "complete":
+            progress = DownloadCompleted(
+                shard_metadata=shard,
+                node_id=self.node_id,
+                total_bytes=initial_progress.total_bytes,
+            )
+            self.download_status[model_id] = progress
+            await self.event_sender.send(
+                NodeDownloadProgress(download_progress=progress)
+            )
+            await self.event_sender.send(
+                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
+            )
+        else:
+            await self.event_sender.send(
+                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
+            )
+            download_task = DownloadModel(
+                instance_id=task.instance_id,
+                shard_metadata=shard,
+                task_id=task.task_id,
+                task_status=task.task_status,
+            )
+            self._handle_shard_download_process(download_task, initial_progress)
+
     def _handle_shard_download_process(
         self,
         task: DownloadModel,
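The worker keeps a per-model status map so repeated planning passes cannot start duplicate downloads for either the main or the draft model. That idempotency guard, in a minimal sketch with stand-in types:

# Minimal sketch of the idempotent download-status guard, stand-in types only.
from enum import Enum, auto

class Status(Enum):
    PENDING = auto()
    ONGOING = auto()
    COMPLETED = auto()

download_status: dict[str, Status] = {}

def should_start_download(model_id: str) -> bool:
    # Ongoing/completed entries suppress duplicate downloads; anything else starts one.
    return download_status.get(model_id) not in (Status.ONGOING, Status.COMPLETED)

assert should_start_download("tiny-draft")
download_status["tiny-draft"] = Status.ONGOING
assert not should_start_download("tiny-draft")  # replanning is now a no-op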
@@ -8,10 +8,12 @@ from exo.shared.types.tasks import (
     ChatCompletion,
     ConnectToGroup,
     CreateRunner,
+    DownloadDraftModel,
     DownloadModel,
     ImageEdits,
     ImageGeneration,
     LoadModel,
+    SetDraftModel,
     Shutdown,
     StartWarmup,
     Task,
@@ -40,6 +42,16 @@ from exo.shared.types.worker.runners import (
 from exo.worker.runner.runner_supervisor import RunnerSupervisor


+def _is_download_in_progress_or_complete(
+    model_id: ModelId,
+    download_status: Mapping[ModelId, DownloadProgress],
+) -> bool:
+    """Check if model download is in progress or complete."""
+    return model_id in download_status and isinstance(
+        download_status[model_id], (DownloadOngoing, DownloadCompleted)
+    )
+
+
 def plan(
     node_id: NodeId,
     # Runners is expected to be FRESH and so should not come from state
@@ -59,9 +71,11 @@ def plan(
         _kill_runner(runners, all_runners, instances)
         or _create_runner(node_id, runners, instances)
         or _model_needs_download(runners, download_status)
+        or _draft_model_needs_download(runners, download_status, instances)
         or _init_distributed_backend(runners, all_runners)
-        or _load_model(runners, all_runners, global_download_status)
+        or _load_model(runners, all_runners, global_download_status, download_status)
         or _ready_to_warmup(runners, all_runners)
+        or _set_draft_model(runners, instances, download_status)
         or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
     )
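plan returns the first non-None task from an ordered chain of checks, which is exactly what the `or`-chaining expresses: earlier concerns (downloads) gate later ones (loading, draft setup, pending work). A stripped-down sketch of the pattern:

# Stripped-down sketch of the "first non-None wins" planner pattern.
from typing import Callable, Optional

Check = Callable[[], Optional[str]]

def plan(checks: list[Check]) -> Optional[str]:
    for check in checks:
        if (task := check()) is not None:
            return task  # earlier checks take priority over later ones
    return None

# Draft download outranks draft loading, mirroring the ordering above.
checks: list[Check] = [lambda: None, lambda: "DownloadDraftModel", lambda: "SetDraftModel"]
assert plan(checks) == "DownloadDraftModel"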
@@ -119,12 +133,9 @@ def _model_needs_download(
 ) -> DownloadModel | None:
     for runner in runners.values():
         model_id = runner.bound_instance.bound_shard.model_card.model_id
-        if isinstance(runner.status, RunnerIdle) and (
-            model_id not in download_status
-            or not isinstance(
-                download_status[model_id], (DownloadOngoing, DownloadCompleted)
-            )
-        ):
+        if isinstance(
+            runner.status, RunnerIdle
+        ) and not _is_download_in_progress_or_complete(model_id, download_status):
             # We don't invalidate download_status randomly in case a file gets deleted on disk
             return DownloadModel(
                 instance_id=runner.bound_instance.instance.instance_id,
@@ -132,6 +143,43 @@
             )


+def _draft_model_needs_download(
+    runners: Mapping[RunnerId, RunnerSupervisor],
+    download_status: Mapping[ModelId, DownloadProgress],
+    instances: Mapping[InstanceId, Instance],
+) -> DownloadDraftModel | None:
+    """Check if draft model needs download for rank 0 runner.
+
+    Triggers download when:
+    - RunnerIdle with draft model (initial setup)
+    - RunnerReady with new draft model (updated via API)
+    """
+    rank_0_runner = next(
+        (r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
+        None,
+    )
+    if rank_0_runner is None:
+        return None
+    if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
+        return None
+
+    # Use current instance state (may have been updated via API)
+    instance_id = rank_0_runner.bound_instance.instance.instance_id
+    current_instance = instances.get(instance_id)
+    if current_instance is None:
+        return None
+
+    draft_model_id = current_instance.draft_model
+    if draft_model_id is None:
+        return None
+    if _is_download_in_progress_or_complete(draft_model_id, download_status):
+        return None
+    return DownloadDraftModel(
+        instance_id=instance_id,
+        model_id=str(draft_model_id),
+    )
+
+
 def _init_distributed_backend(
     runners: Mapping[RunnerId, RunnerSupervisor],
     all_runners: Mapping[RunnerId, RunnerStatus],
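Only the rank-0 runner ever handles the draft model, presumably because the rank-0 shard drives sampling in the pipelined setup, as the "(rank 0 only)" docstrings suggest. A tiny sketch of the selection idiom both draft-model helpers use, with stand-in tuples in place of RunnerSupervisor:

# Stand-in for the rank-0 selection used by both draft-model helpers.
runners = [("runner-a", 1), ("runner-b", 0), ("runner-c", 2)]  # (id, device_rank)

rank_0 = next((r for r in runners if r[1] == 0), None)
assert rank_0 == ("runner-b", 0)

# next(..., None) keeps the planner total: no rank-0 runner means "nothing to do"
assert next((r for r in runners if r[1] == 99), None) is None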
@@ -186,10 +234,12 @@ def _load_model(
     runners: Mapping[RunnerId, RunnerSupervisor],
     all_runners: Mapping[RunnerId, RunnerStatus],
     global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
+    download_status: Mapping[ModelId, DownloadProgress],
 ) -> LoadModel | None:
     for runner in runners.values():
         instance = runner.bound_instance.instance
         shard_assignments = instance.shard_assignments
+        shard = runner.bound_instance.bound_shard

         all_local_downloads_complete = all(
             nid in global_download_status
@@ -203,6 +253,14 @@
         if not all_local_downloads_complete:
             continue

+        # Rank 0 with draft model must wait for draft download before loading
+        if shard.device_rank == 0:
+            draft_model_id = instance.draft_model
+            if draft_model_id is not None and not isinstance(
+                download_status.get(draft_model_id), DownloadCompleted
+            ):
+                continue
+
         is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
         if is_single_node_instance and isinstance(runner.status, RunnerIdle):
             return LoadModel(instance_id=instance.instance_id)
@@ -262,6 +320,53 @@ def _ready_to_warmup(
     return None


+def _set_draft_model(
+    runners: Mapping[RunnerId, RunnerSupervisor],
+    instances: Mapping[InstanceId, Instance],
+    download_status: Mapping[ModelId, DownloadProgress],
+) -> SetDraftModel | None:
+    """Check if rank 0 runner needs to load or clear a draft model."""
+    rank_0_runner = next(
+        (r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
+        None,
+    )
+    if rank_0_runner is None:
+        return None
+    if not isinstance(rank_0_runner.status, RunnerReady):
+        return None
+
+    instance_id = rank_0_runner.bound_instance.instance.instance_id
+    current_instance = instances.get(instance_id)
+    if current_instance is None:
+        return None
+
+    # Compare runner's bound draft model vs current instance draft model
+    runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
+    current_draft_model = current_instance.draft_model
+
+    if runner_draft_model == current_draft_model:
+        return None
+
+    # Draft model changed - need to update
+    if current_draft_model is None:
+        # Clear draft model
+        return SetDraftModel(
+            instance_id=instance_id,
+            model_id=None,
+            num_draft_tokens=4,
+        )
+
+    # Wait for draft model to be downloaded
+    if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
+        return None
+
+    return SetDraftModel(
+        instance_id=instance_id,
+        model_id=str(current_draft_model),
+        num_draft_tokens=current_instance.num_draft_tokens,
+    )
+
+
 def _pending_tasks(
     runners: Mapping[RunnerId, RunnerSupervisor],
     tasks: Mapping[TaskId, Task],
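_set_draft_model is a small reconciliation loop: the desired draft model lives on the instance (updated by the InstanceDraftModelUpdated event), the actual one on the runner's bound instance, and a SetDraftModel task is emitted only on divergence. The same idea in isolation, with stand-in values:

# Desired-vs-actual reconciliation in isolation (stand-in types).
def reconcile(desired: str | None, actual: str | None, downloaded: set[str]) -> str | None:
    if desired == actual:
        return None                    # converged: nothing to emit
    if desired is None:
        return "SetDraftModel(clear)"  # clearing needs no download
    if desired not in downloaded:
        return None                    # wait; the download check fires first
    return f"SetDraftModel({desired})"

assert reconcile("tiny-draft", None, set()) is None  # still downloading
assert reconcile("tiny-draft", None, {"tiny-draft"}) == "SetDraftModel(tiny-draft)"
assert reconcile(None, "tiny-draft", set()) == "SetDraftModel(clear)"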
@@ -2,7 +2,7 @@ import base64
 import time
 from collections.abc import Generator
 from functools import cache
-from typing import Literal
+from typing import Literal, cast

 import mlx.core as mx
 from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -32,6 +32,7 @@ from exo.shared.types.tasks import (
     ImageEdits,
     ImageGeneration,
     LoadModel,
+    SetDraftModel,
     Shutdown,
     StartWarmup,
     Task,
@@ -71,6 +72,7 @@ from exo.worker.engines.mlx.utils_mlx import (
     apply_chat_template,
     detect_thinking_prompt_suffix,
     initialize_mlx,
+    load_draft_model,
     load_mlx_items,
     mlx_force_oom,
 )
@@ -99,6 +101,7 @@ def main(
     model: Model | DistributedImageModel | None = None
     tokenizer = None
     group = None
+    draft_model: Model | None = None  # Loaded during warmup if instance has draft_model

     current_status: RunnerStatus = RunnerIdle()
     logger.info("runner created")
@@ -164,6 +167,16 @@
                         f"Unknown model task(s): {shard_metadata.model_card.tasks}"
                     )

+                # Load draft model for speculative decoding (rank 0 only)
+                if (
+                    instance.draft_model is not None
+                    and shard_metadata.device_rank == 0
+                ):
+                    logger.info(f"Loading draft model: {instance.draft_model}")
+                    draft_model = cast(
+                        Model, load_draft_model(str(instance.draft_model))
+                    )
+
                 current_status = RunnerLoaded()
                 logger.info("runner loaded")
             case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -185,7 +198,8 @@
                 toks = warmup_inference(
                     model=model,
                     tokenizer=tokenizer,
-                    # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
+                    draft_model=draft_model,
+                    num_draft_tokens=instance.num_draft_tokens,
                 )
                 logger.info(f"warmed up by generating {toks} tokens")
                 logger.info(
@@ -231,6 +245,8 @@
                     tokenizer=tokenizer,
                     task=task_params,
                     prompt=prompt,
+                    draft_model=draft_model,
+                    num_draft_tokens=instance.num_draft_tokens,
                 )

                 # GPT-OSS specific parsing to match other model formats.
@@ -412,6 +428,52 @@

                 current_status = RunnerReady()
                 logger.info("runner ready")
+            case SetDraftModel(
+                model_id=draft_model_id, num_draft_tokens=num_tokens
+            ) if isinstance(current_status, RunnerReady):
+                current_status = RunnerWarmingUp()
+                logger.info("runner warming up (setting draft model)")
+                event_sender.send(
+                    RunnerStatusUpdated(
+                        runner_id=runner_id, runner_status=current_status
+                    )
+                )
+                assert model is not None
+                assert tokenizer is not None
+
+                if draft_model_id is None:
+                    # Clear draft model
+                    logger.info("Clearing draft model")
+                    draft_model = None
+                    instance = instance.model_copy(
+                        update={
+                            "draft_model": None,
+                            "num_draft_tokens": 4,
+                        }
+                    )
+                else:
+                    # Load new draft model
+                    logger.info(f"Loading draft model: {draft_model_id}")
+                    draft_model = cast(
+                        Model, load_draft_model(ModelId(draft_model_id))
+                    )
+                    instance = instance.model_copy(
+                        update={
+                            "draft_model": ModelId(draft_model_id),
+                            "num_draft_tokens": num_tokens,
+                        }
+                    )
+                    # Warm up with speculative decoding
+                    logger.info("Warming up with new draft model")
+                    warmup_inference(
+                        model=cast(Model, model),
+                        tokenizer=tokenizer,
+                        draft_model=draft_model,
+                        num_draft_tokens=num_tokens,
+                    )
+                    logger.info("Draft model loaded and warmed up")
+
+                current_status = RunnerReady()
             case Shutdown():
                 current_status = RunnerShuttingDown()
                 logger.info("runner shutting down")
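The SetDraftModel case is guarded with `if isinstance(current_status, RunnerReady)` and round-trips through RunnerWarmingUp, so the UI observes the same lifecycle as the initial warmup. That transition in miniature, with a stand-in enum:

# Runner status round-trip for a draft-model swap, in miniature.
from enum import Enum, auto

class RunnerState(Enum):
    READY = auto()
    WARMING_UP = auto()

def set_draft_model(status: RunnerState) -> list[RunnerState]:
    if status is not RunnerState.READY:
        return [status]  # guard: only a ready runner accepts the task
    transitions = [RunnerState.WARMING_UP]  # load + warmup happen here
    transitions.append(RunnerState.READY)   # then the runner is ready again
    return transitions

assert set_draft_model(RunnerState.READY) == [RunnerState.WARMING_UP, RunnerState.READY]
assert set_draft_model(RunnerState.WARMING_UP) == [RunnerState.WARMING_UP]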
@@ -432,7 +494,7 @@
             RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
         )
         if isinstance(current_status, RunnerShutdown):
-            del model, tokenizer, group
+            del model, tokenizer, group, draft_model
             mx.clear_cache()
             import gc