mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 23:06:23 -05:00
Compare commits
1 Commits
handle-mes
...
support-ml
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8748eade1 |
@@ -143,7 +143,12 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
@@ -310,6 +315,7 @@ class API:
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
mlx_device=payload.mlx_device,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
@@ -350,6 +356,7 @@ class API:
|
||||
sharding: Sharding = Sharding.Pipeline,
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
mlx_device: MlxDevice = MlxDevice.Auto,
|
||||
) -> Instance:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
|
||||
@@ -360,6 +367,7 @@ class API:
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
mlx_device=mlx_device,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
|
||||
@@ -159,6 +159,7 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
jaccl_devices=mlx_jaccl_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
mlx_device=command.mlx_device,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
@@ -173,6 +174,7 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
mlx_device=command.mlx_device,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -8,7 +8,12 @@ from pydantic import BaseModel, Field
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -226,6 +231,7 @@ class PlaceInstanceParams(BaseModel):
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
mlx_device: MlxDevice = MlxDevice.Auto
|
||||
|
||||
|
||||
class CreateInstanceParams(BaseModel):
|
||||
|
||||
@@ -8,7 +8,12 @@ from exo.shared.types.api import (
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -38,6 +43,7 @@ class PlaceInstance(BaseCommand):
|
||||
sharding: Sharding
|
||||
instance_meta: InstanceMeta
|
||||
min_nodes: int
|
||||
mlx_device: MlxDevice = MlxDevice.Auto
|
||||
|
||||
|
||||
class CreateInstance(BaseCommand):
|
||||
|
||||
@@ -16,9 +16,16 @@ class InstanceMeta(str, Enum):
|
||||
MlxJaccl = "MlxJaccl"
|
||||
|
||||
|
||||
class MlxDevice(str, Enum):
|
||||
Auto = "Auto"
|
||||
Cpu = "Cpu"
|
||||
Gpu = "Gpu"
|
||||
|
||||
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
mlx_device: MlxDevice = MlxDevice.Auto
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -4,7 +4,7 @@ import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxDevice, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
@@ -35,6 +35,15 @@ def entrypoint(
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Set MLX compute device before importing runner (which imports mlx.core at module scope)
|
||||
mlx_device = bound_instance.instance.mlx_device
|
||||
if mlx_device != MlxDevice.Auto:
|
||||
import mlx.core as mx
|
||||
|
||||
device = mx.cpu if mlx_device == MlxDevice.Cpu else mx.gpu
|
||||
mx.set_default_device(device)
|
||||
logger.info(f"MLX device set to: {mlx_device}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
Reference in New Issue
Block a user