mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 11:43:17 -05:00
Compare commits
3 Commits
alexcheema
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90e2a20091 | ||
|
|
55152fa99d | ||
|
|
e7f61c3494 |
@@ -69,6 +69,8 @@ export interface Instance {
|
||||
runnerToShard?: Record<string, unknown>;
|
||||
nodeToRunner?: Record<string, string>;
|
||||
};
|
||||
draftModel?: string;
|
||||
numDraftTokens?: number;
|
||||
}
|
||||
|
||||
// Granular node state types from the new state structure
|
||||
|
||||
@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
SetInstanceDraftModel,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
@@ -32,6 +33,7 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
InstanceDraftModelUpdated,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
@@ -308,6 +310,14 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case SetInstanceDraftModel():
|
||||
generated_events.append(
|
||||
InstanceDraftModelUpdated(
|
||||
instance_id=command.instance_id,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
|
||||
@@ -147,6 +147,8 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
jaccl_devices=mlx_jaccl_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
@@ -161,6 +163,8 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -12,6 +12,7 @@ from exo.shared.types.events import (
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceDraftModelUpdated,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
@@ -69,6 +70,8 @@ def event_apply(event: Event, state: State) -> State:
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case InstanceDraftModelUpdated():
|
||||
return apply_instance_draft_model_updated(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodeDownloadProgress():
|
||||
@@ -187,6 +190,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
def apply_instance_draft_model_updated(
|
||||
event: InstanceDraftModelUpdated, state: State
|
||||
) -> State:
|
||||
if event.instance_id not in state.instances:
|
||||
return state
|
||||
instance = state.instances[event.instance_id]
|
||||
updated_instance = instance.model_copy(
|
||||
update={
|
||||
"draft_model": event.draft_model,
|
||||
"num_draft_tokens": event.num_draft_tokens,
|
||||
}
|
||||
)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
event.instance_id: updated_instance,
|
||||
}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {
|
||||
**state.runners,
|
||||
|
||||
@@ -72,6 +72,14 @@ class DeleteDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class SetInstanceDraftModel(BaseCommand):
|
||||
"""Set or update the draft model for an existing instance."""
|
||||
|
||||
instance_id: InstanceId
|
||||
draft_model: ModelId | None # None to disable speculative decoding
|
||||
num_draft_tokens: int = 4
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
|
||||
|
||||
@@ -84,6 +92,7 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| SetInstanceDraftModel
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -68,6 +68,14 @@ class InstanceDeleted(BaseEvent):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class InstanceDraftModelUpdated(BaseEvent):
|
||||
"""Draft model updated on an existing instance."""
|
||||
|
||||
instance_id: InstanceId
|
||||
draft_model: ModelId | None
|
||||
num_draft_tokens: int
|
||||
|
||||
|
||||
class RunnerStatusUpdated(BaseEvent):
|
||||
runner_id: RunnerId
|
||||
runner_status: RunnerStatus
|
||||
@@ -141,6 +149,7 @@ Event = (
|
||||
| TaskAcknowledged
|
||||
| InstanceCreated
|
||||
| InstanceDeleted
|
||||
| InstanceDraftModelUpdated
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodeTimedOut
|
||||
|
||||
@@ -40,6 +40,12 @@ class DownloadModel(BaseTask): # emitted by Worker
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class DownloadDraftModel(BaseTask): # emitted by Worker
|
||||
"""Download a draft model for speculative decoding (rank 0 only)."""
|
||||
|
||||
model_id: str # HuggingFace model ID
|
||||
|
||||
|
||||
class LoadModel(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
@@ -80,9 +86,17 @@ class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class SetDraftModel(BaseTask): # emitted by Worker
|
||||
"""Load or clear a draft model on an already-running instance."""
|
||||
|
||||
model_id: str | None # HuggingFace model ID, or None to clear
|
||||
num_draft_tokens: int = 4
|
||||
|
||||
|
||||
Task = (
|
||||
CreateRunner
|
||||
| DownloadModel
|
||||
| DownloadDraftModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
@@ -90,4 +104,5 @@ Task = (
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
| SetDraftModel
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.common import Host, Id, ModelId, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -226,6 +226,27 @@ def load_mlx_items(
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
def load_draft_model(model_id: ModelId) -> nn.Module:
|
||||
"""Load a draft model for speculative decoding (rank 0 only).
|
||||
|
||||
Draft models are small models (typically 0.5B-2B parameters) used to
|
||||
generate candidate tokens quickly, which are then verified by the main
|
||||
model in a single forward pass.
|
||||
|
||||
Assumes the model has already been downloaded by the worker.
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID for the draft model
|
||||
|
||||
Returns:
|
||||
The loaded draft model
|
||||
"""
|
||||
model_path = build_model_path(model_id)
|
||||
draft_model, _ = load_model(model_path, strict=True)
|
||||
logger.info(f"Loaded draft model from {model_path}")
|
||||
return draft_model
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
|
||||
Reference in New Issue
Block a user