Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
f8748eade1 Add MLX compute device backend selection (cpu/gpu/auto)
Add MlxDevice enum (Auto/Cpu/Gpu) that flows from the API through
PlaceInstance commands to instance creation and runner initialization.
The device is set via mx.set_default_device() in bootstrap.py before
the runner imports MLX, ensuring all subsequent operations use the
chosen device.

This enables Linux CPU-only machines to participate in inference and
allows forcing CPU mode for testing or when GPU has issues. Default
is Auto (current behavior) for full backward compatibility.

Closes #958

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 10:52:11 -08:00
7 changed files with 42 additions and 12 deletions

View File

@@ -143,7 +143,12 @@ from exo.shared.types.openai_responses import (
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -310,6 +315,7 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
mlx_device=payload.mlx_device,
)
await self._send(command)
@@ -350,6 +356,7 @@ class API:
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
mlx_device: MlxDevice = MlxDevice.Auto,
) -> Instance:
model_card = await ModelCard.load(model_id)
@@ -360,6 +367,7 @@ class API:
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
mlx_device=mlx_device,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,

View File

@@ -159,6 +159,7 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
mlx_device=command.mlx_device,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -173,6 +174,7 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
mlx_device=command.mlx_device,
)
return target_instances

View File

@@ -211,14 +211,6 @@ class Router:
pass
except AllQueuesFullError:
logger.warning(f"All peer queues full, dropping message on {topic}")
except RuntimeError as e:
if "MessageTooLarge" in str(e):
logger.error(
f"Message too large for gossipsub on topic {topic} "
f"({len(data)} bytes), dropping message"
)
else:
raise
def get_node_id_keypair(

View File

@@ -8,7 +8,12 @@ from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
@@ -226,6 +231,7 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
mlx_device: MlxDevice = MlxDevice.Auto
class CreateInstanceParams(BaseModel):

View File

@@ -8,7 +8,12 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -38,6 +43,7 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
mlx_device: MlxDevice = MlxDevice.Auto
class CreateInstance(BaseCommand):

View File

@@ -16,9 +16,16 @@ class InstanceMeta(str, Enum):
MlxJaccl = "MlxJaccl"
class MlxDevice(str, Enum):
Auto = "Auto"
Cpu = "Cpu"
Gpu = "Gpu"
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
mlx_device: MlxDevice = MlxDevice.Auto
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -4,7 +4,7 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance, MlxDevice, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -35,6 +35,15 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Set MLX compute device before importing runner (which imports mlx.core at module scope)
mlx_device = bound_instance.instance.mlx_device
if mlx_device != MlxDevice.Auto:
import mlx.core as mx
device = mx.cpu if mlx_device == MlxDevice.Cpu else mx.gpu
mx.set_default_device(device)
logger.info(f"MLX device set to: {mlx_device}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main