Compare commits

...

1 Commits

Author SHA1 Message Date
Alex Cheema
c93376f0fb Add speculative decoding support with draft models
Implements speculative decoding using MLX-LM's built-in stream_generate(draft_model=...)
to accelerate inference. A small draft model generates candidate tokens which are
verified by the main model in a single forward pass.

Key changes:
- Add draft_model and num_draft_tokens to instance configuration
- Auto-download draft models during warmup if not present
- Dashboard UI for selecting draft model and token count
- Display draft model info on running instance cards

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 02:42:22 +00:00
11 changed files with 277 additions and 67 deletions

View File

@@ -69,6 +69,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
interface RawNodeProfile {

View File

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -58,6 +58,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
draftModel: string | null;
numDraftTokens: number;
}
function saveLaunchDefaults(): void {
@@ -66,6 +68,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
draftModel: selectedDraftModel,
numDraftTokens: selectedNumDraftTokens,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
@@ -88,24 +92,36 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
// Apply draft model if it exists in the available models (check against hugging_face_id)
if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
selectedDraftModel = defaults.draftModel;
}
// Apply num draft tokens if valid
if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
selectedNumDraftTokens = defaults.numDraftTokens;
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let selectedDraftModel = $state<string | null>(null);
let selectedNumDraftTokens = $state<number>(4);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
@@ -113,6 +129,8 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
let isDraftModelDropdownOpen = $state(false);
let draftModelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -362,47 +380,39 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
// This also supports draft_model for speculative decoding
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
draft_model: selectedDraftModel,
num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
};
response = await fetch('/place_instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ instance: instanceData })
body: JSON.stringify(placePayload)
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -816,30 +826,34 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
}
};
draftModel?: string;
numDraftTokens?: number;
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -850,7 +864,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -858,8 +872,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
}
function formatLastUpdate(): string {
@@ -1345,6 +1363,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1678,8 +1699,80 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/each}
</div>
</div>
<!-- Draft Model (Speculative Decoding) -->
<div>
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
<div class="relative">
<button
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftModelDropdownOpen}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-40"
onclick={() => isDraftModelDropdownOpen = false}
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
></div>
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftModelDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span>None</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens (only show when draft model selected) -->
{#if selectedDraftModel}
<div class="flex items-center gap-2 mt-2">
<span class="text-xs text-white/50 font-mono">Tokens:</span>
<div class="flex items-center gap-1">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
>{n}</button>
{/each}
</div>
</div>
{/if}
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}

View File

@@ -200,6 +200,8 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)

View File

@@ -151,6 +151,8 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -164,6 +166,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -161,6 +161,8 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -172,6 +174,8 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod

View File

@@ -2,7 +2,7 @@ from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,6 +25,8 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +20,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -119,6 +119,8 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -135,8 +137,6 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -149,19 +149,31 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
logger.info(out.text)
stats: GenerationStats | None = None

View File

@@ -258,6 +258,27 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,

View File

@@ -3,7 +3,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -35,6 +36,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -57,6 +59,7 @@ def plan(
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _draft_model_needs_download(runners, download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -128,6 +131,57 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
"""Check if draft model needs download (for speculative decoding).
Only rank 0 needs the draft model, and only after the main model is loaded.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
shard = runner.bound_instance.bound_shard
# Only check when runner is loaded and ready for warmup
if not isinstance(runner.status, RunnerLoaded):
continue
# Only rank 0 loads the draft model
if shard.device_rank != 0:
continue
# Check if instance has a draft model configured
draft_model_id = instance.draft_model
if draft_model_id is None:
continue
# Check if draft model needs download
if draft_model_id not in download_status or not isinstance(
download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
):
# Create minimal shard metadata for draft model download
draft_shard = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=draft_model_id,
pretty_name=str(draft_model_id),
storage_size=Memory.from_bytes(0), # Unknown, will be determined during download
n_layers=1, # Placeholder
hidden_size=1, # Placeholder
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
return DownloadModel(
instance_id=instance.instance_id,
shard_metadata=draft_shard,
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],

View File

@@ -56,6 +56,7 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
@@ -110,6 +111,7 @@ def main(
model = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -178,11 +180,20 @@ def main(
)
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -212,11 +223,13 @@ def main(
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
# Generate responses (draft_model loaded at warmup if configured)
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
@@ -265,7 +278,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del model, tokenizer, group, draft_model
mx.clear_cache()
import gc