mirror of https://github.com/exo-explore/exo.git, synced 2026-02-15 00:23:07 -05:00

Compare commits: 8 commits on leo/add-gl (author: alexcheema)

| SHA1 |
|---|
| 3d4f8130c9 |
| 1c3cc699d4 |
| 6762a0a8ba |
| 6524820e5b |
| 5a28642790 |
| 90e2a20091 |
| 55152fa99d |
| e7f61c3494 |
.mlx_typings/mlx_lm/models/step3p5.pyi (Normal file, 151 lines added)
@@ -0,0 +1,151 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .switch_layers import SwitchGLU

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    vocab_size: int
    num_attention_heads: int
    num_attention_groups: int
    head_dim: int
    intermediate_size: int
    rms_norm_eps: float
    rope_theta: float
    rope_scaling: Optional[Dict[str, Any]]
    max_position_embeddings: int
    sliding_window: int
    layer_types: Optional[List[str]]
    yarn_only_types: Optional[List[str]]
    partial_rotary_factors: Optional[List[float]]
    attention_other_setting: Optional[Dict[str, Any]]
    use_head_wise_attn_gate: bool
    moe_num_experts: int
    moe_top_k: int
    moe_intermediate_size: int
    share_expert_dim: int
    moe_layers_enum: Optional[str]
    moe_router_scaling_factor: float
    norm_expert_weight: bool
    swiglu_limits: Optional[List[float]]
    swiglu_limits_shared: Optional[List[float]]
    tie_word_embeddings: bool

class Step3p5MLP(nn.Module):
    hidden_size: int
    intermediate_size: int
    gate_proj: nn.Linear
    up_proj: nn.Linear
    down_proj: nn.Linear
    limit: Optional[float]

    def __init__(
        self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
    ) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class Step3p5MoEGate(nn.Module):
    top_k: int
    n_routed_experts: int
    routed_scaling_factor: float
    norm_topk_prob: bool
    gate: nn.Linear
    router_bias: mx.array

    def __init__(self, args: ModelArgs) -> None: ...
    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...

class Step3p5MoE(nn.Module):
    gate: Step3p5MoEGate
    switch_mlp: SwitchGLU
    share_expert: Step3p5MLP
    sharding_group: Optional[mx.distributed.Group]

    def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class Step3p5Attention(nn.Module):
    is_sliding: bool
    num_heads: int
    num_kv_heads: int
    head_dim: int
    scale: float
    q_proj: nn.Linear
    k_proj: nn.Linear
    v_proj: nn.Linear
    o_proj: nn.Linear
    q_norm: nn.Module
    k_norm: nn.Module
    use_head_wise_attn_gate: bool
    g_proj: nn.Linear
    rope: nn.Module

    def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class Step3p5DecoderLayer(nn.Module):
    self_attn: Step3p5Attention
    is_sliding: bool
    is_moe_layer: bool
    mlp: Step3p5MLP | Step3p5MoE
    input_layernorm: nn.Module
    post_attention_layernorm: nn.Module

    def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class Step3p5Model(nn.Module):
    args: ModelArgs
    vocab_size: int
    num_layers: int
    embed_tokens: nn.Embedding
    layers: list[Step3p5DecoderLayer]
    norm: nn.Module
    _swa_idx: Optional[int]
    _full_idx: Optional[int]

    def __init__(self, args: ModelArgs) -> None: ...
    def __call__(
        self,
        x: mx.array,
        cache: Optional[List[Any]] = None,
    ) -> mx.array: ...

class Model(nn.Module):
    args: ModelArgs
    model_type: str
    model: Step3p5Model
    lm_head: nn.Linear

    def __init__(self, args: ModelArgs) -> None: ...
    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[List[Any]] = None,
    ) -> mx.array: ...
    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
    def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
    @property
    def layers(self) -> list[Step3p5DecoderLayer]: ...
    def make_cache(self) -> list[Any]: ...
    @property
    def cast_predicate(self) -> Any: ...
    @property
    def quant_predicate(self) -> Any: ...
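The stub above only fixes signatures; `Step3p5MoEGate.__call__` is declared to return a pair of arrays (per-token expert weights and indices). As a rough illustration of what a top-k MoE router of this shape computes, here is a minimal MLX sketch. It assumes plain softmax routing scaled by a `moe_router_scaling_factor`-style constant; the real Step-3.5 gate (router bias, top-k weight normalization, etc.) may differ.

```python
import mlx.core as mx

def topk_route(
    x: mx.array,            # (tokens, hidden)
    gate_weight: mx.array,  # (hidden, n_experts)
    top_k: int,
    scaling_factor: float = 1.0,
) -> tuple[mx.array, mx.array]:
    """Illustrative top-k routing: returns (weights, indices) per token."""
    scores = mx.softmax(x @ gate_weight, axis=-1)                        # (tokens, n_experts)
    indices = mx.argpartition(-scores, kth=top_k - 1, axis=-1)[..., :top_k]
    weights = mx.take_along_axis(scores, indices, axis=-1) * scaling_factor
    return weights, indices
```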
@@ -806,6 +806,7 @@
    isFavorite={favorites.has(group.id)}
    {selectedModelId}
    {canModelFit}
    {getModelFitStatus}
    onToggleExpand={() => toggleGroupExpanded(group.id)}
    onSelectModel={handleSelect}
    {onToggleFavorite}
@@ -72,6 +72,8 @@ export interface Instance {
    runnerToShard?: Record<string, unknown>;
    nodeToRunner?: Record<string, string>;
  };
  draftModel?: string;
  numDraftTokens?: number;
}

// Granular node state types from the new state structure
@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 114572190076
@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 159039627774
@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 209082699847
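The three new model cards above share one schema. For reference, a minimal sketch of reading such a card with Python's standard tomllib (Python 3.11+); the filename is hypothetical and only keys shown in the cards are accessed:

```python
import tomllib

# Hypothetical path; the real cards live in exo's model card directory.
with open("Step-3.5-Flash-4bit.toml", "rb") as f:
    card = tomllib.load(f)

size_gib = card["storage_size"]["in_bytes"] / 2**30
print(
    f"{card['model_id']} ({card['quantization']}): "
    f"{card['n_layers']} layers, ~{size_gib:.0f} GiB on disk"
)
```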
@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
    PlaceInstance,
    RequestEventLog,
    SendInputChunk,
    SetInstanceDraftModel,
    TaskFinished,
    TestCommand,
    TextGeneration,
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
    IndexedEvent,
    InputChunkReceived,
    InstanceDeleted,
    InstanceDraftModelUpdated,
    NodeGatheredInfo,
    NodeTimedOut,
    TaskCreated,
@@ -319,6 +321,14 @@ class Master:
                        chunk=chunk,
                    )
                )
            case SetInstanceDraftModel():
                generated_events.append(
                    InstanceDraftModelUpdated(
                        instance_id=command.instance_id,
                        draft_model=command.draft_model,
                        num_draft_tokens=command.num_draft_tokens,
                    )
                )
            case TaskFinished():
                generated_events.append(
                    TaskDeleted(
@@ -153,6 +153,8 @@ def place_instance(
                shard_assignments=shard_assignments,
                jaccl_devices=mlx_jaccl_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )
        case InstanceMeta.MlxRing:
            ephemeral_port = random_ephemeral_port()
@@ -167,6 +169,8 @@ def place_instance(
                shard_assignments=shard_assignments,
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )

    return target_instances
@@ -12,6 +12,7 @@ from exo.shared.types.events import (
    InputChunkReceived,
    InstanceCreated,
    InstanceDeleted,
    InstanceDraftModelUpdated,
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
@@ -72,6 +73,8 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case InstanceDraftModelUpdated():
            return apply_instance_draft_model_updated(event, state)
        case NodeTimedOut():
            return apply_node_timed_out(event, state)
        case NodeDownloadProgress():
@@ -190,6 +193,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    return state.model_copy(update={"instances": new_instances})


def apply_instance_draft_model_updated(
    event: InstanceDraftModelUpdated, state: State
) -> State:
    if event.instance_id not in state.instances:
        return state
    instance = state.instances[event.instance_id]
    updated_instance = instance.model_copy(
        update={
            "draft_model": event.draft_model,
            "num_draft_tokens": event.num_draft_tokens,
        }
    )
    new_instances: Mapping[InstanceId, Instance] = {
        **state.instances,
        event.instance_id: updated_instance,
    }
    return state.model_copy(update={"instances": new_instances})


def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
    new_runners: Mapping[RunnerId, RunnerStatus] = {
        **state.runners,
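Taken together with the Master case arm above, the draft-model update follows exo's usual command-to-event-to-state path. A condensed sketch of that round trip, using only names visible in this diff (any additional fields required by BaseCommand/BaseEvent, such as ids, are omitted here):

```python
# 1. A client asks for a draft model on an existing instance.
command = SetInstanceDraftModel(
    instance_id=instance_id,
    draft_model=ModelId("mlx-community/Qwen3-0.6B-4bit"),  # hypothetical draft model
    num_draft_tokens=4,
)

# 2. The Master translates the command into an event (see the case arm above).
event = InstanceDraftModelUpdated(
    instance_id=command.instance_id,
    draft_model=command.draft_model,
    num_draft_tokens=command.num_draft_tokens,
)

# 3. event_apply folds the event into the shared state.
state = apply_instance_draft_model_updated(event, state)
assert state.instances[command.instance_id].draft_model == command.draft_model
```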
@@ -189,6 +189,7 @@ class ConfigData(BaseModel):
        ["MiniMaxM2ForCausalLM"],
        ["LlamaForCausalLM"],
        ["GptOssForCausalLM"],
        ["Step3p5ForCausalLM"],
    ]

    @model_validator(mode="before")
@@ -38,6 +38,8 @@ class PlaceInstance(BaseCommand):
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int
    draft_model: ModelId | None = None
    num_draft_tokens: int = 4


class CreateInstance(BaseCommand):
@@ -72,6 +74,14 @@ class DeleteDownload(BaseCommand):
    model_id: ModelId


class SetInstanceDraftModel(BaseCommand):
    """Set or update the draft model for an existing instance."""

    instance_id: InstanceId
    draft_model: ModelId | None  # None to disable speculative decoding
    num_draft_tokens: int = 4


class CancelDownload(BaseCommand):
    target_node_id: NodeId
    model_id: ModelId
@@ -89,6 +99,7 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | SetInstanceDraftModel
    | TaskFinished
    | SendInputChunk
)
@@ -5,7 +5,7 @@ from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
-from exo.shared.types.common import CommandId, Id, NodeId, SessionId
+from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -68,6 +68,14 @@ class InstanceDeleted(BaseEvent):
    instance_id: InstanceId


class InstanceDraftModelUpdated(BaseEvent):
    """Draft model updated on an existing instance."""

    instance_id: InstanceId
    draft_model: ModelId | None
    num_draft_tokens: int


class RunnerStatusUpdated(BaseEvent):
    runner_id: RunnerId
    runner_status: RunnerStatus
@@ -141,6 +149,7 @@ Event = (
    | TaskAcknowledged
    | InstanceCreated
    | InstanceDeleted
    | InstanceDraftModelUpdated
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeTimedOut
@@ -40,6 +40,12 @@ class DownloadModel(BaseTask):  # emitted by Worker
    shard_metadata: ShardMetadata


class DownloadDraftModel(BaseTask):  # emitted by Worker
    """Download a draft model for speculative decoding (rank 0 only)."""

    model_id: str  # HuggingFace model ID


class LoadModel(BaseTask):  # emitted by Worker
    pass

@@ -80,9 +86,17 @@ class Shutdown(BaseTask):  # emitted by Worker
    runner_id: RunnerId


class SetDraftModel(BaseTask):  # emitted by Worker
    """Load or clear a draft model on an already-running instance."""

    model_id: str | None  # HuggingFace model ID, or None to clear
    num_draft_tokens: int = 4


Task = (
    CreateRunner
    | DownloadModel
    | DownloadDraftModel
    | ConnectToGroup
    | LoadModel
    | StartWarmup
@@ -90,4 +104,5 @@ Task = (
    | ImageGeneration
    | ImageEdits
    | Shutdown
    | SetDraftModel
)
@@ -2,7 +2,7 @@ from enum import Enum

from pydantic import model_validator

-from exo.shared.types.common import Host, Id, NodeId
+from exo.shared.types.common import Host, Id, ModelId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
    instance_id: InstanceId
    shard_assignments: ShardAssignments
    draft_model: ModelId | None = None  # For speculative decoding (rank 0 only)
    num_draft_tokens: int = 4  # Tokens to draft per iteration (when draft_model is set)

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)
@@ -35,6 +35,9 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer

from exo.shared.logging import logger
@@ -264,6 +267,19 @@ def pipeline_auto_parallel(
            )
        )

    if isinstance(inner_model_instance, Step35InnerModel):
        inner_model_instance.num_layers = len(layers)
        sliding_layers = [
            i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
        ]
        full_layers = [
            i
            for i, layer in enumerate(layers)
            if not getattr(layer, "is_sliding", True)
        ]
        inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
        inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]

    _set_layers(model, layers)

    assert isinstance(layers, list), (
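The bookkeeping above recomputes which of the pipeline-local layers use sliding-window attention versus full attention after the layer list has been sliced for this rank. A small self-contained illustration of the same logic with a made-up is_sliding pattern:

```python
# Hypothetical pattern for a 6-layer pipeline slice (True = sliding-window layer).
is_sliding = [True, True, False, True, True, False]

sliding_layers = [i for i, s in enumerate(is_sliding) if s]
full_layers = [i for i, s in enumerate(is_sliding) if not s]

swa_idx = 0 if not sliding_layers else sliding_layers[0]   # first sliding layer
full_idx = 0 if not full_layers else full_layers[0]        # first full-attention layer
print(swa_idx, full_idx)  # 0 2
```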
@@ -427,6 +443,14 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, Step35Model):
        tensor_parallel_sharding_strategy = Step35ShardingStrategy(
            group,
            all_to_sharded_linear,
            sharded_to_all_linear,
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

@@ -981,3 +1005,46 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
            layer.mlp.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
            mx.eval(layer)
        return model


class Step35ShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
        model = cast(Step35Model, model)

        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)

            layer.self_attn.num_heads //= self.N
            layer.self_attn.num_kv_heads //= self.N

            if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
                layer.self_attn.g_proj = self.all_to_sharded_linear(
                    layer.self_attn.g_proj
                )

            if isinstance(layer.mlp, Step35MLP):
                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
            else:
                layer.mlp.sharding_group = self.group
                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)

            mx.eval(layer)
        return model
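In the strategy above, q/k/v (and the optional g_proj gate) go through all_to_sharded_linear while o_proj goes through sharded_to_all_linear, so each rank keeps 1/N of the attention heads, which is why the head counts are divided by the group size. A quick sanity check of that arithmetic with assumed head counts (not taken from the Step-3.5 config):

```python
# Assumed values for illustration only.
num_heads, num_kv_heads, tp_size = 64, 8, 4

# Head counts must divide evenly across the tensor-parallel group.
assert num_heads % tp_size == 0 and num_kv_heads % tp_size == 0

heads_per_rank = num_heads // tp_size        # 16 query heads per rank
kv_heads_per_rank = num_kv_heads // tp_size  # 2 KV heads per rank
print(heads_per_rank, kv_heads_per_rank)
```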
@@ -223,6 +223,27 @@ def load_mlx_items(
    return cast(Model, model), tokenizer


def load_draft_model(model_id: ModelId) -> nn.Module:
    """Load a draft model for speculative decoding (rank 0 only).

    Draft models are small models (typically 0.5B-2B parameters) used to
    generate candidate tokens quickly, which are then verified by the main
    model in a single forward pass.

    Assumes the model has already been downloaded by the worker.

    Args:
        model_id: HuggingFace model ID for the draft model

    Returns:
        The loaded draft model
    """
    model_path = build_model_path(model_id)
    draft_model, _ = load_model(model_path, strict=True)
    logger.info(f"Loaded draft model from {model_path}")
    return draft_model


def shard_and_load(
    shard_metadata: ShardMetadata,
    group: Group,
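load_draft_model only loads the small model; the docstring describes the draft-then-verify pattern it serves. A minimal, greedy-acceptance sketch of that loop, independent of exo's actual runner code: draft_step and target_logits are stand-in callables, and production implementations typically reuse KV caches and use rejection sampling rather than exact-match acceptance.

```python
import mlx.core as mx

def speculative_step(
    tokens: list[int],
    draft_step,      # callable: list[int] -> int (greedy next token from the draft model)
    target_logits,   # callable: list[int] -> mx.array of shape (len(sequence), vocab)
    num_draft_tokens: int = 4,
) -> list[int]:
    # 1. The draft model proposes num_draft_tokens candidates autoregressively.
    proposed: list[int] = []
    context = list(tokens)
    for _ in range(num_draft_tokens):
        nxt = draft_step(context)
        proposed.append(nxt)
        context.append(nxt)

    # 2. The main model scores the extended sequence in a single forward pass.
    logits = target_logits(context)

    # 3. Accept the longest prefix the main model agrees with; on the first
    #    mismatch, keep the main model's token instead and stop.
    accepted: list[int] = []
    for i, tok in enumerate(proposed):
        target_tok = int(mx.argmax(logits[len(tokens) + i - 1]))
        if target_tok == tok:
            accepted.append(tok)
        else:
            accepted.append(target_tok)
            break
    return tokens + accepted
```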