Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-01 01:33:03 -05:00)
Compare commits: 4 commits, ciaran/par ... pyproject-
Commits: c890bd0beb, 9ba61f3733, d9eca75895, 9dabde7e57
@@ -19,7 +19,7 @@ dependencies = [
    "anyio==4.11.0",
    "mlx==0.30.4; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.4; sys_platform == 'linux'",
    "mlx-lm",
    "mlx-lm==0.30.5",
    "tiktoken>=0.12.0", # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
@@ -31,8 +31,6 @@ dependencies = [
]

[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"

# dependencies only required for development
@@ -46,12 +44,6 @@ dev = [
    "ruff>=0.11.13",
]

# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.
[project.optional-dependencies]
# cuda = [
#     "mlx[cuda]==0.26.3",
# ]

###
# workspace configuration
###
@@ -63,8 +55,8 @@ members = [

[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }

@@ -116,7 +108,7 @@ environments = [
###

[tool.ruff]
extend-exclude = ["shared/protobufs/**", "*mlx_typings/**", "rust/exo_pyo3_bindings/**"]
extend-exclude = ["*mlx_typings/**", "rust/exo_pyo3_bindings/**"]

[tool.ruff.lint]
extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
@@ -166,9 +166,8 @@ class ResumableShardDownloader(ShardDownloader):
        for task in asyncio.as_completed(tasks):
            try:
                yield await task
            # TODO: except Exception
            except Exception as e:
                logger.error("Error downloading shard:", e)
                logger.warning(f"Error downloading shard: {type(e).__name__}")

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
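For illustration, a minimal standalone sketch of the as_completed pattern in the hunk above, using a hypothetical download_one coroutine rather than the real downloader; a failed task is logged by exception type only and the loop keeps yielding the remaining results.

import asyncio
import logging

logger = logging.getLogger(__name__)

async def download_one(name: str) -> str:
    # Hypothetical stand-in for a single shard download.
    if name == "bad":
        raise RuntimeError("connection reset")
    await asyncio.sleep(0)
    return name

async def download_all(names: list[str]):
    tasks = [asyncio.create_task(download_one(n)) for n in names]
    for task in asyncio.as_completed(tasks):
        try:
            yield await task
        except Exception as e:
            # Same shape as the new logging line above: type name only, no traceback.
            logger.warning(f"Error downloading shard: {type(e).__name__}")

async def main() -> None:
    async for result in download_all(["a", "bad", "b"]):
        print(result)

asyncio.run(main())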
@@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
@@ -66,7 +65,9 @@ from exo.shared.types.api import (
|
||||
StartDownloadParams,
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
StreamOptions,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -113,17 +114,10 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
    return f"image/{image_format or 'png'}"


def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
    """Ensure advanced params has a seed set for distributed consistency."""
    if params is None:
        return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
    if params.seed is None:
        return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
    return params
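A short usage sketch of _ensure_seed, with a stand-in pydantic model in place of AdvancedImageParams (whose full definition is not part of this diff): the seed is pinned once on the API node so every worker derives the same value.

import random
from pydantic import BaseModel

class AdvancedImageParams(BaseModel):
    # Stand-in with only the fields needed for the example.
    seed: int | None = None
    guidance_scale: float | None = None

def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
    if params is None:
        return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
    if params.seed is None:
        return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
    return params

params = _ensure_seed(AdvancedImageParams(guidance_scale=4.0))
assert params.seed is not None        # seed is now fixed before the command is broadcast
assert params.guidance_scale == 4.0   # model_copy preserves the other fields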
def chunk_to_response(
    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
    chunk: TokenChunk | ToolCallChunk,
    command_id: CommandId,
    usage: Usage | None,
) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        id=command_id,
@@ -148,6 +142,7 @@ def chunk_to_response(
                finish_reason=chunk.finish_reason,
            )
        ],
        usage=usage,
    )
@@ -532,9 +527,10 @@ class API:
            del self._chat_completion_queues[command_id]

    async def _generate_chat_stream(
        self, command_id: CommandId
        self, command_id: CommandId, stream_options: StreamOptions | None = None
    ) -> AsyncGenerator[str, None]:
        """Generate chat completion stream as JSON strings."""
        include_usage = stream_options.include_usage if stream_options else False

        async for chunk in self._chat_chunk_stream(command_id):
            assert not isinstance(chunk, ImageChunk)
@@ -550,8 +546,10 @@ class API:
                yield "data: [DONE]\n\n"
                return

            usage = chunk.usage if include_usage else None

            chunk_response: ChatCompletionResponse = chunk_to_response(
                chunk, command_id
                chunk, command_id, usage=usage
            )
            logger.debug(f"chunk_response: {chunk_response}")
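The stream_options handling above follows the OpenAI streaming convention; a sketch of the request body and of the shape of the final streamed chunk. Only stream_options/include_usage and the usage attachment come from this diff; the remaining field names are the usual OpenAI ones and are assumptions here.

request_body = {
    "model": "some-model-id",
    "messages": [{"role": "user", "content": "hello"}],
    "stream": True,
    "stream_options": {"include_usage": True},
}

# With include_usage set, the chunk carrying finish_reason also carries usage;
# without it, usage stays None on every streamed chunk.
final_chunk_shape = {
    "id": "<command_id>",
    "choices": [{"delta": {}, "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
}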
@@ -569,6 +567,7 @@ class API:
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -593,6 +592,9 @@ class API:
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.usage is not None:
|
||||
usage = chunk.usage
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
@@ -614,6 +616,7 @@ class API:
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
@@ -701,7 +704,7 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
self._generate_chat_stream(command.command_id, payload.stream_options),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -782,9 +785,6 @@ class API:
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -1033,9 +1033,6 @@ class API:
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -1067,7 +1064,6 @@ class API:
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
advanced_params = _ensure_seed(advanced_params)
|
||||
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
@@ -94,35 +94,20 @@ def get_shard_assignments_for_pipeline_parallel(
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    # Determine CFG parallelism topology
    # CFG parallel only for even node counts with CFG models (2+ nodes)
    use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
    cfg_world_size = 2 if use_cfg_parallel else 1
    pipeline_world_size = world_size // cfg_world_size

    # For CFG parallel, we only need to allocate layers for one pipeline group
    # (both CFG groups run the same layers). Use the first pipeline group's nodes.
    pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
    pipeline_memory = sum(
        (node_memory[node_id].ram_available for node_id in pipeline_node_ids),
        start=Memory(),
    )

    layer_allocations = allocate_layers_proportionally(
        total_layers=total_layers,
        memory_fractions=[
            node_memory[node_id].ram_available.in_bytes / pipeline_memory.in_bytes
            for node_id in pipeline_node_ids
            node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
            for node_id in cycle.node_ids
        ],
    )

    # Validate each pipeline node has sufficient memory for its assigned layers
    # Use integer arithmetic to avoid floating point precision issues
    total_storage_bytes = model_card.storage_size.in_bytes
    for i, node_id in enumerate(pipeline_node_ids):
        node_layers = layer_allocations[i]
        # Integer division then multiply to get conservative estimate
        required_memory = (total_storage_bytes * node_layers) // total_layers
    # Validate each node has sufficient memory for its assigned layers
    memory_per_layer = model_card.storage_size.in_bytes / total_layers
    for i, (node_id, node_layers) in enumerate(
        zip(cycle.node_ids, layer_allocations, strict=True)
    ):
        required_memory = node_layers * memory_per_layer
        available_memory = node_memory[node_id].ram_available.in_bytes
        if required_memory > available_memory:
            raise ValueError(
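A worked example of the new allocation and memory check. allocate_layers_proportionally is not shown in this diff, so the largest-remainder rounding below is an assumption used only to make the sketch runnable; the per-layer memory arithmetic matches the new code.

def allocate_layers_proportionally(total_layers: int, memory_fractions: list[float]) -> list[int]:
    # Hypothetical rounding scheme; the real helper may differ.
    raw = [f * total_layers for f in memory_fractions]
    alloc = [int(r) for r in raw]
    leftovers = sorted(range(len(raw)), key=lambda i: raw[i] - alloc[i], reverse=True)
    for i in leftovers[: total_layers - sum(alloc)]:
        alloc[i] += 1
    return alloc

total_layers = 60
storage_bytes = 24 * 1024**3                   # 24 GB of weights
ram_available = [32 * 1024**3, 16 * 1024**3]   # two nodes with 32 GB and 16 GB free
cycle_memory = sum(ram_available)

allocations = allocate_layers_proportionally(
    total_layers, [m / cycle_memory for m in ram_available]
)
print(allocations)  # [40, 20]: layers split 2:1, matching the memory split

memory_per_layer = storage_bytes / total_layers       # 0.4 GB per layer
for node_layers, available in zip(allocations, ram_available, strict=True):
    required = node_layers * memory_per_layer          # 16 GB and 8 GB respectively
    assert required <= available                        # otherwise the code above raises ValueError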
@@ -131,69 +116,24 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
f"but only has {available_memory / (1024**3):.2f} GB available"
|
||||
)
|
||||
|
||||
# CFG group 0: pipeline ranks in ascending order (0, 1, 2, ...)
|
||||
# CFG group 1: pipeline ranks in descending order (reversed)
|
||||
# This places both "last stages" as ring neighbors for CFG exchange.
|
||||
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
|
||||
(1, r) for r in reversed(range(pipeline_world_size))
|
||||
]
|
||||
|
||||
cfg_pipeline_to_device: dict[tuple[int, int], int] = {
|
||||
(cfg_rank, pipeline_rank): i
|
||||
for i, (cfg_rank, pipeline_rank) in enumerate(position_to_cfg_pipeline)
|
||||
}
|
||||
|
||||
for i, node_id in enumerate(cycle.node_ids):
|
||||
cfg_rank, pipeline_rank = position_to_cfg_pipeline[i]
|
||||
|
||||
layers_before = sum(layer_allocations[:pipeline_rank])
|
||||
node_layers = layer_allocations[pipeline_rank]
|
||||
|
||||
is_first_stage = pipeline_rank == 0
|
||||
is_last_stage = pipeline_rank == pipeline_world_size - 1
|
||||
|
||||
if is_last_stage:
|
||||
next_pipeline_device = None
|
||||
else:
|
||||
next_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank + 1)]
|
||||
|
||||
if is_first_stage:
|
||||
prev_pipeline_device = None
|
||||
else:
|
||||
prev_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank - 1)]
|
||||
|
||||
if is_last_stage and use_cfg_parallel:
|
||||
other_cfg_rank = 1 - cfg_rank
|
||||
cfg_peer_device = cfg_pipeline_to_device[(other_cfg_rank, pipeline_rank)]
|
||||
else:
|
||||
cfg_peer_device = None
|
||||
|
||||
first_pipeline_device = cfg_pipeline_to_device[(cfg_rank, 0)]
|
||||
last_pipeline_device = cfg_pipeline_to_device[
|
||||
(cfg_rank, pipeline_world_size - 1)
|
||||
]
|
||||
|
||||
layers_assigned = 0
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
runner_id = RunnerId()
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=layers_before,
|
||||
end_layer=layers_before + node_layers,
|
||||
start_layer=layers_assigned,
|
||||
end_layer=layers_assigned + node_layers,
|
||||
n_layers=total_layers,
|
||||
cfg_rank=cfg_rank,
|
||||
cfg_world_size=cfg_world_size,
|
||||
explicit_pipeline_rank=pipeline_rank,
|
||||
next_pipeline_device=next_pipeline_device,
|
||||
prev_pipeline_device=prev_pipeline_device,
|
||||
cfg_peer_device=cfg_peer_device,
|
||||
first_pipeline_device=first_pipeline_device,
|
||||
last_pipeline_device=last_pipeline_device,
|
||||
)
|
||||
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node_id] = runner_id
|
||||
layers_assigned += node_layers
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
|
||||
@@ -5,7 +5,6 @@ from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
@@ -21,7 +20,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory():
|
||||
@@ -488,195 +487,3 @@ def test_get_shard_assignments_insufficient_memory_raises():
|
||||
get_shard_assignments(
|
||||
model_card, selected_cycle, Sharding.Pipeline, node_memory
|
||||
)
|
||||
|
||||
|
||||
class TestCfgParallelPlacement:
|
||||
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
|
||||
topology = Topology()
|
||||
for node_id in node_ids:
|
||||
topology.add_node(node_id)
|
||||
|
||||
for i, node_id in enumerate(node_ids):
|
||||
next_node = node_ids[(i + 1) % len(node_ids)]
|
||||
conn = Connection(
|
||||
source=node_id,
|
||||
sink=next_node,
|
||||
edge=create_socket_connection(i + 1),
|
||||
)
|
||||
topology.add_connection(conn)
|
||||
|
||||
return topology
|
||||
|
||||
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
|
||||
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# Both nodes should have all layers (no pipeline split)
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.start_layer == 0
|
||||
assert shard.end_layer == 60
|
||||
assert shard.cfg_world_size == 2
|
||||
|
||||
cfg_ranks = sorted(
|
||||
s.cfg_rank for s in shards if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
assert cfg_ranks == [0, 1]
|
||||
|
||||
def test_four_nodes_cfg_model_uses_hybrid(self):
|
||||
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
|
||||
nodes = [NodeId() for _ in range(4)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 4]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 4
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.cfg_world_size == 2
|
||||
assert shard.pipeline_world_size == 2
|
||||
|
||||
# Check we have 2 nodes in each CFG group
|
||||
cfg_0_shards = [
|
||||
s
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 0
|
||||
]
|
||||
cfg_1_shards = [
|
||||
s
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 1
|
||||
]
|
||||
assert len(cfg_0_shards) == 2
|
||||
assert len(cfg_1_shards) == 2
|
||||
|
||||
# Both CFG groups should have the same layer assignments
|
||||
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
|
||||
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
|
||||
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
|
||||
|
||||
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
|
||||
"""Three nodes (odd) with CFG model should use sequential CFG."""
|
||||
nodes = [NodeId() for _ in range(3)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 3]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 3
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
# cfg_world_size = 1 means sequential CFG
|
||||
assert shard.cfg_world_size == 1
|
||||
assert shard.cfg_rank == 0
|
||||
|
||||
def test_two_nodes_non_cfg_model_uses_pipeline(self):
|
||||
"""Two nodes with non-CFG model should use pure pipeline."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("flux-test"),
|
||||
n_layers=57,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=False, # Non-CFG model
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
# cfg_world_size = 1 means no CFG parallel
|
||||
assert shard.cfg_world_size == 1
|
||||
assert shard.cfg_rank == 0
|
||||
|
||||
# Should have actual layer sharding (pipeline)
|
||||
layer_ranges = sorted(
|
||||
(s.start_layer, s.end_layer)
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
# First shard starts at 0, last shard ends at 57
|
||||
assert layer_ranges[0][0] == 0
|
||||
assert layer_ranges[-1][1] == 57
|
||||
|
||||
@@ -47,7 +47,6 @@ class ModelCard(CamelCaseModel):
|
||||
supports_tensor: bool
|
||||
tasks: list[ModelTask]
|
||||
components: list[ComponentInfo] | None = None
|
||||
uses_cfg: bool = False
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -563,7 +562,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
uses_cfg=True,
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
@@ -598,7 +596,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
uses_cfg=True,
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
@@ -684,7 +681,6 @@ def _generate_image_model_quant_variants(
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
uses_cfg=base_card.uses_cfg,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
@@ -704,7 +700,6 @@ def _generate_image_model_quant_variants(
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
uses_cfg=base_card.uses_cfg,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -116,8 +116,8 @@ class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: PromptTokensDetails | None = None
    completion_tokens_details: CompletionTokensDetails | None = None
    prompt_tokens_details: PromptTokensDetails
    completion_tokens_details: CompletionTokensDetails


class StreamingChoiceResponse(BaseModel):
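With both details fields now required rather than optional, every Usage has to be constructed with its sub-objects; an illustrative instance using the same field names as the generation code later in this diff (the counts are made-up values).

from exo.shared.types.api import CompletionTokensDetails, PromptTokensDetails, Usage

usage = Usage(
    prompt_tokens=128,
    completion_tokens=42,
    total_tokens=170,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=96),
    completion_tokens_details=CompletionTokensDetails(reasoning_tokens=17),
)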
@@ -170,7 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
    generation_stats: GenerationStats | None = None


class ChatCompletionTaskParams(BaseModel):
class StreamOptions(BaseModel):
    include_usage: bool = False


class ChatCompletionTaskParams(TaggedModel):
    model_config = ConfigDict(extra="ignore")

    model: str
    frequency_penalty: float | None = None
    messages: list[ChatCompletionMessage]
@@ -184,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
    seed: int | None = None
    stop: str | list[str] | None = None
    stream: bool = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[dict[str, Any]] | None = None
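The switch to TaggedModel with ConfigDict(extra="ignore") lets clients send unrecognised OpenAI-compatible fields without triggering a validation error, presumably overriding a stricter setting on the shared base model. A minimal pydantic sketch of that behaviour with a stand-in model, not the real ChatCompletionTaskParams:

from pydantic import BaseModel, ConfigDict

class Params(BaseModel):
    model_config = ConfigDict(extra="ignore")
    model: str
    stream: bool = False

p = Params.model_validate({"model": "m", "stream": True, "logit_bias": {"1": 2}})
print(p)  # logit_bias is silently dropped instead of raising a validation error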
@@ -2,7 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
|
||||
|
||||
|
||||
class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -54,7 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
|
||||
|
||||
class ChatCompletion(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ChatCompletionTaskParams
|
||||
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
@@ -6,6 +6,7 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -24,6 +25,7 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
@@ -57,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -57,62 +57,8 @@ class PipelineShardMetadata(BaseShardMetadata):
|
||||
|
||||
Layers are represented as a half-open interval [start_layer, end_layer),
|
||||
where start_layer is inclusive and end_layer is exclusive.
|
||||
|
||||
CFG parallelism fields:
|
||||
- cfg_rank: 0 = positive branch, 1 = negative branch (or 0 if no CFG parallel)
|
||||
- cfg_world_size: 1 = sequential CFG, 2 = parallel CFG
|
||||
|
||||
Communication rank fields (explicit to support ring topology):
|
||||
- next_pipeline_device: device to send to in pipeline forward pass
|
||||
- prev_pipeline_device: device to receive from in pipeline forward pass
|
||||
- cfg_peer_device: device for CFG exchange (last stage only)
|
||||
- first_pipeline_device: device of first stage in same CFG group (for latent return)
|
||||
"""
|
||||
|
||||
cfg_rank: int = 0
|
||||
cfg_world_size: int = 1
|
||||
|
||||
# Explicit pipeline position (CFG group 1 uses reversed pipeline order)
|
||||
explicit_pipeline_rank: int | None = None
|
||||
|
||||
next_pipeline_device: int | None = None
|
||||
prev_pipeline_device: int | None = None
|
||||
cfg_peer_device: int | None = None
|
||||
first_pipeline_device: int | None = None
|
||||
last_pipeline_device: int | None = None
|
||||
|
||||
@property
|
||||
def pipeline_world_size(self) -> int:
|
||||
return self.world_size // self.cfg_world_size
|
||||
|
||||
@property
|
||||
def pipeline_rank(self) -> int:
|
||||
if self.explicit_pipeline_rank is not None:
|
||||
return self.explicit_pipeline_rank
|
||||
return self.device_rank % self.pipeline_world_size
|
||||
|
||||
@property
|
||||
def is_pipeline_first(self) -> bool:
|
||||
return self.pipeline_rank == 0
|
||||
|
||||
@property
|
||||
def is_pipeline_last(self) -> bool:
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.model_card.model_id,
|
||||
self.start_layer,
|
||||
self.end_layer,
|
||||
self.n_layers,
|
||||
self.device_rank,
|
||||
self.world_size,
|
||||
self.cfg_rank,
|
||||
self.cfg_world_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TensorShardMetadata(BaseShardMetadata):
|
||||
pass
|
||||
|
||||
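The half-open [start_layer, end_layer) convention kept in the docstring above makes stage sizes and boundaries compose without off-by-one errors; a small illustrative split of a 60-layer model across two stages.

n_layers = 60
stages = [(0, 30), (30, 60)]  # [start_layer, end_layer) per pipeline stage

assert sum(end - start for start, end in stages) == n_layers                  # counts add up exactly
assert all(stages[i][1] == stages[i + 1][0] for i in range(len(stages) - 1))  # no gap, no overlap

start, end = stages[0]
local_layers = list(range(start, end))  # stage 0 owns layers 0..29; layer 30 opens stage 1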
@@ -37,12 +37,7 @@ class DistributedImageModel:
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
has_layer_sharding = (
|
||||
shard_metadata.start_layer != 0
|
||||
or shard_metadata.end_layer != shard_metadata.n_layers
|
||||
)
|
||||
|
||||
if group is not None and has_layer_sharding:
|
||||
if group is not None:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
|
||||
@@ -86,27 +86,6 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Get embeddings for a single CFG branch (positive or negative).
|
||||
|
||||
Used for sequential CFG and CFG parallel modes where we process
|
||||
one branch at a time instead of batching.
|
||||
|
||||
Args:
|
||||
positive: True for positive prompt, False for negative prompt
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- embeds: [1, seq, hidden] prompt embeddings
|
||||
- mask: [1, seq] attention mask or None
|
||||
- pooled: [1, hidden] pooled embeddings or None
|
||||
- conditioning_latents: [1, latent_seq, latent_dim] or None
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
|
||||
_config: ImageModelConfig
|
||||
|
||||
@@ -64,12 +64,6 @@ class FluxPromptData(PromptData):
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
return None
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Flux doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
|
||||
|
||||
|
||||
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
|
||||
def __init__(
|
||||
|
||||
@@ -133,24 +133,6 @@ class QwenPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
@@ -153,24 +153,6 @@ class QwenEditPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, batched_cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Optional, final
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
@@ -22,16 +20,6 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class CfgBranch:
|
||||
positive: bool
|
||||
embeds: mx.array
|
||||
mask: mx.array | None
|
||||
pooled: mx.array | None
|
||||
cond_latents: mx.array | None
|
||||
|
||||
|
||||
def calculate_patch_heights(
|
||||
latent_height: int, num_patches: int
|
||||
) -> tuple[list[int], int]:
|
||||
@@ -84,11 +72,22 @@ class DiffusionRunner:
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
self._init_cfg_topology(shard_metadata)
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.num_patches = (
|
||||
num_patches if num_patches else max(1, self.pipeline_world_size)
|
||||
)
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
@@ -98,48 +97,6 @@ class DiffusionRunner:
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
def _init_cfg_topology(self, shard_metadata: PipelineShardMetadata) -> None:
|
||||
"""Initialize CFG and pipeline topology from shard metadata."""
|
||||
if self.group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.start_layer = 0
|
||||
self.end_layer = self.config.total_blocks
|
||||
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
|
||||
self.pipeline_world_size = 1
|
||||
self.pipeline_rank = 0
|
||||
|
||||
self.next_pipeline_rank: int | None = None
|
||||
self.prev_pipeline_rank: int | None = None
|
||||
self.cfg_peer_rank: int | None = None
|
||||
self.first_pipeline_rank: int = 0
|
||||
self.last_pipeline_rank: int = 0
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.cfg_rank = shard_metadata.cfg_rank
|
||||
self.cfg_world_size = shard_metadata.cfg_world_size
|
||||
self.cfg_parallel = self.cfg_world_size > 1
|
||||
|
||||
self.pipeline_world_size = shard_metadata.pipeline_world_size
|
||||
self.pipeline_rank = shard_metadata.pipeline_rank
|
||||
|
||||
self.next_pipeline_rank = shard_metadata.next_pipeline_device
|
||||
self.prev_pipeline_rank = shard_metadata.prev_pipeline_device
|
||||
self.cfg_peer_rank = shard_metadata.cfg_peer_device
|
||||
|
||||
assert shard_metadata.first_pipeline_device is not None
|
||||
assert shard_metadata.last_pipeline_device is not None
|
||||
self.first_pipeline_rank = shard_metadata.first_pipeline_device
|
||||
self.last_pipeline_rank = shard_metadata.last_pipeline_device
|
||||
|
||||
def _compute_assigned_blocks(self) -> None:
|
||||
"""Determine which joint/single blocks this stage owns."""
|
||||
start = self.start_layer
|
||||
@@ -176,11 +133,11 @@ class DiffusionRunner:
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self.pipeline_rank == 0
|
||||
return self.rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
return self.rank == self.world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
@@ -191,97 +148,6 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
|
||||
"""Yield the CFG branches this node should process.
|
||||
|
||||
- No CFG: yields one branch (positive)
|
||||
- CFG parallel: yields one branch (our assigned branch)
|
||||
- Sequential CFG: yields two branches (positive, then negative)
|
||||
"""
|
||||
if not self.adapter.needs_cfg:
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
elif self.cfg_parallel:
|
||||
positive = self.cfg_rank == 0
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
|
||||
yield CfgBranch(
|
||||
positive=positive,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
else:
|
||||
pos_embeds, pos_mask, pos_pooled, pos_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=True)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=pos_embeds,
|
||||
mask=pos_mask,
|
||||
pooled=pos_pooled,
|
||||
cond_latents=pos_cond,
|
||||
)
|
||||
neg_embeds, neg_mask, neg_pooled, neg_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=False)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=False,
|
||||
embeds=neg_embeds,
|
||||
mask=neg_mask,
|
||||
pooled=neg_pooled,
|
||||
cond_latents=neg_cond,
|
||||
)
|
||||
|
||||
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
|
||||
if len(results) == 1:
|
||||
positive, noise = results[0]
|
||||
if self.cfg_parallel and self.is_last_stage:
|
||||
# TODO(ciaran): try to remove
|
||||
mx.eval(noise)
|
||||
return self._exchange_and_apply_guidance(noise, positive)
|
||||
return noise
|
||||
|
||||
noise_neg = next(n for p, n in results if not p)
|
||||
noise_pos = next(n for p, n in results if p)
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _exchange_and_apply_guidance(
|
||||
self, noise: mx.array, is_positive: bool
|
||||
) -> mx.array:
|
||||
assert self.group is not None
|
||||
assert self.cfg_peer_rank is not None
|
||||
|
||||
if is_positive:
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_neg)
|
||||
noise_pos = noise
|
||||
else:
|
||||
noise_pos = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_pos)
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = noise
|
||||
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
|
||||
scale = self._get_effective_guidance_scale()
|
||||
assert scale is not None
|
||||
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
|
||||
|
||||
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -598,9 +464,7 @@ class DiffusionRunner:
|
||||
) -> mx.array:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif (
|
||||
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
|
||||
):
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
@@ -624,29 +488,42 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
# Reset caches before each branch to ensure no state contamination
|
||||
self._reset_all_caches()
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate([latents, latents], axis=0)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = latents
|
||||
|
||||
noise = self._forward_pass(
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=cond_latents,
|
||||
)
|
||||
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=guidance_scale
|
||||
)
|
||||
|
||||
noise = self._forward_pass(
|
||||
latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
noise = self._combine_cfg_results(results)
|
||||
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
|
||||
|
||||
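The replacement code above runs classifier-free guidance as one batched forward pass instead of per-branch passes: latents are duplicated along the batch axis, the model sees positive and negative conditioning together, and the two halves are split and recombined. A minimal mlx sketch with a toy denoiser and the conventional guidance formula; the real combination lives in adapter.apply_guidance and may differ.

import mlx.core as mx

def denoise(latents: mx.array, prompt_embeds: mx.array) -> mx.array:
    # Toy stand-in for the transformer forward pass.
    return latents + prompt_embeds.mean(axis=1, keepdims=True)

latents = mx.zeros((1, 16, 8))
pos_embeds = mx.ones((1, 4, 8))
neg_embeds = mx.zeros((1, 4, 8))

step_latents = mx.concatenate([latents, latents], axis=0)        # batch positive + negative
prompt_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)

noise = denoise(step_latents, prompt_embeds)
noise_pos, noise_neg = mx.split(noise, 2, axis=0)

guidance_scale = 4.0
noise = noise_neg + guidance_scale * (noise_pos - noise_neg)      # conventional CFG combine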
def _create_patches(
|
||||
@@ -697,7 +574,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
t, config, pooled_prompt_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
@@ -709,17 +586,16 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
@@ -744,30 +620,27 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_pipeline_rank, group=self.group
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
assert self.next_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
@@ -782,9 +655,8 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
@@ -807,65 +679,75 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
|
||||
cond_latents = branch.cond_latents
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
branch.mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
step_latents = mx.concatenate(
|
||||
[scaled_hidden_states, scaled_hidden_states], axis=0
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
encoder_mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
noise = self._combine_cfg_results(results)
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
|
||||
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise, timestep=t, latents=prev_latents
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.first_pipeline_rank, group=self.group
|
||||
)
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.last_pipeline_rank, group=self.group
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
@@ -884,10 +766,39 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
@@ -899,52 +810,31 @@ class DiffusionRunner:
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.last_pipeline_rank, group=self.group
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(branch.mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
branch.embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=branch.embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=step_patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
noise = self._combine_cfg_results(results)
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
|
||||
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise,
|
||||
@@ -954,9 +844,7 @@ class DiffusionRunner:
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx],
|
||||
self.first_pipeline_rank,
|
||||
group=self.group,
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
@@ -996,12 +884,11 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -1010,7 +897,7 @@ class DiffusionRunner:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
@@ -1038,34 +925,29 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_pipeline_rank, group=self.group
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -1080,10 +962,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
|
||||
@@ -10,8 +10,11 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
CompletionTokensDetails,
FinishReason,
GenerationStats,
PromptTokensDetails,
Usage,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
@@ -39,7 +42,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
) -> tuple[float, int]:
"""Prefill the KV cache with prompt tokens.

This runs the model over the prompt tokens to populate the cache,
@@ -50,7 +53,7 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
return 0.0, 0

logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
@@ -85,7 +88,7 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
return tokens_per_sec, num_tokens


def warmup_inference(
@@ -169,6 +172,8 @@ def mlx_generate(
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)

logger.info(f"{is_bench=}")

# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")

@@ -204,7 +209,9 @@ def mlx_generate(
)

# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
)

# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
@@ -212,28 +219,43 @@ def mlx_generate(
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
usage: Usage | None = None
in_thinking = False
reasoning_tokens = 0
think_start = tokenizer.think_start
think_end = tokenizer.think_end
for completion_tokens, out in enumerate(
stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
),
start=1,
):
generated_text_parts.append(out.text)
logger.info(out.text)

if think_start is not None and out.text == think_start:
in_thinking = True
elif think_end is not None and out.text == think_end:
in_thinking = False
if in_thinking:
reasoning_tokens += 1

stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
@@ -245,11 +267,24 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)

usage = Usage(
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=reasoning_tokens
),
)

yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=usage,
)

if out.finish_reason is not None:

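For the token accounting introduced above: prefill() now also reports how many prompt tokens it consumed, and since stream_generate is fed only the final prompt token, the stats sum the two counts. A toy check with hypothetical numbers (not code from the repository):

prompt = list(range(10))           # hypothetical 10-token prompt
prefill_tokens = len(prompt[:-1])  # 9 tokens populate the cache via prefill()
last_token_count = 1               # stream_generate starts from the last token
# mirrors prompt_tokens=int(prefill_tokens + out.prompt_tokens) in GenerationStats
assert prefill_tokens + last_token_count == len(prompt)
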
@@ -61,7 +61,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
@@ -277,9 +277,11 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)

completion_tokens = 0
for response in mlx_generator:
match response:
case GenerationResponse():
completion_tokens += 1
if (
device_rank == 0
and response.finish_reason == "error"
@@ -307,6 +309,7 @@ def main(
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -320,6 +323,7 @@ def main(
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
),
)
)
@@ -360,9 +364,8 @@ def main(
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -388,11 +391,7 @@ def main(
image_index += 1
# can we make this more explicit?
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -424,9 +423,8 @@ def main(
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -451,11 +449,7 @@ def main(
)
image_index += 1
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -545,10 +539,10 @@ def parse_gpt_oss(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
]
],
usage=response.usage,
)
tool_arg_parts = []
break
current_tool_name = recipient

# If inside a tool call, accumulate arguments
@@ -694,7 +688,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
yield ToolCallResponse(tool_calls=tools, usage=response.usage)

except (
json.JSONDecodeError,

@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))

def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)

monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)

@@ -182,6 +182,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
text="hi",
token_id=0,
finish_reason="stop",
usage=None,
stats=None,
),
)

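The rank checks rewritten above replace the PipelineShardMetadata-specific flags with a single world-position test. A minimal reading of that condition (the helper below is hypothetical; only the device_rank and world_size field names come from the diff):

def is_final_stage(device_rank: int, world_size: int) -> bool:
    # The last pipeline stage is the highest-ranked device in the world group.
    return device_rank == world_size - 1

assert is_final_stage(device_rank=3, world_size=4)
assert not is_final_stage(device_rank=0, world_size=4)
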
18 tmp/config_examples/opencode.json Normal file
@@ -0,0 +1,18 @@
{
"$schema": "https://opencode.ai/config.json",
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
"provider": {
"exo": {
"api": "http://localhost:52415/v1",
"models": {
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
"name": "GPT OSS 120B",
"limit": {
"context": 32768,
"output": 8192
}
}
}
}
}
}
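The new example config points opencode at a local exo instance through its OpenAI-compatible endpoint. As a quick, illustrative sanity check of that endpoint (not part of the change; the base_url and model id are copied from the file above, the client usage is generic openai-python boilerplate, and it assumes no API key is required):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:52415/v1", api_key="not-needed")
resp = client.chat.completions.create(
    model="mlx-community/gpt-oss-120b-MXFP4-Q8",
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=8,
)
print(resp.choices[0].message.content)
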
30 uv.lock generated
@@ -192,20 +192,14 @@ sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8
wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446 },
{ url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101 },
{ url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948 },
{ url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422 },
{ url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499 },
{ url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928 },
{ url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302 },
{ url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049 },
{ url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793 },
{ url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300 },
{ url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244 },
{ url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828 },
{ url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926 },
{ url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593 },
{ url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354 },
{ url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480 },
{ url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584 },
{ url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443 },
{ url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437 },
@@ -311,10 +305,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807 },
{ url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615 },
{ url = "https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800 },
{ url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707 },
{ url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541 },
{ url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464 },
{ url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838 },
{ url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596 },
{ url = "https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782 },
{ url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381 },
@@ -322,10 +314,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078 },
{ url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460 },
{ url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237 },
{ url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344 },
{ url = "https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564 },
{ url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415 },
{ url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457 },
{ url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074 },
{ url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569 },
{ url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941 },
@@ -333,10 +323,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029 },
{ url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222 },
{ url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280 },
{ url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958 },
{ url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714 },
{ url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970 },
{ url = "https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236 },
{ url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642 },
{ url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126 },
{ url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573 },
@@ -415,7 +403,7 @@ requires-dist = [
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.4" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.4" },
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
{ name = "mlx-lm", specifier = "==0.30.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" },
@@ -1073,7 +1061,7 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.5"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1083,6 +1071,10 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716 },
]

[[package]]
name = "mlx-metal"
@@ -1503,6 +1495,9 @@ version = "11.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/93/0952f2ed8db3a5a4c7a11f91965d6184ebc8cd7cbb7941a260d5f018cd2d/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:1c627742b539bba4309df89171356fcb3cc5a9178355b2727d1b74a6cf155fbd", size = 2128328 },
{ url = "https://files.pythonhosted.org/packages/4b/e8/100c3d114b1a0bf4042f27e0f87d2f25e857e838034e98ca98fe7b8c0a9c/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30b7c02f3899d10f13d7a48163c8969e4e653f8b43416d23d13d1bbfdc93b9f8", size = 2170652 },
{ url = "https://files.pythonhosted.org/packages/aa/86/3f758a28a6e381758545f7cdb4942e1cb79abd271bea932998fc0db93cb6/pillow-11.3.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:7859a4cc7c9295f5838015d8cc0a9c215b77e43d07a25e460f35cf516df8626f", size = 2227443 },
{ url = "https://files.pythonhosted.org/packages/01/f4/91d5b3ffa718df2f53b0dc109877993e511f4fd055d7e9508682e8aba092/pillow-11.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec1ee50470b0d050984394423d96325b744d55c701a439d2bd66089bff963d3c", size = 5278474 },
{ url = "https://files.pythonhosted.org/packages/f9/0e/37d7d3eca6c879fbd9dba21268427dffda1ab00d4eb05b32923d4fbe3b12/pillow-11.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7db51d222548ccfd274e4572fdbf3e810a5e66b00608862f947b163e613b67dd", size = 4686038 },
{ url = "https://files.pythonhosted.org/packages/ff/b0/3426e5c7f6565e752d81221af9d3676fdbb4f352317ceafd42899aaf5d8a/pillow-11.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2d6fcc902a24ac74495df63faad1884282239265c6839a0a6416d33faedfae7e", size = 5864407 },
@@ -2278,7 +2273,7 @@ wheels = [

[[package]]
name = "transformers"
version = "5.0.0"
version = "5.0.0rc3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2287,14 +2282,15 @@ dependencies = [
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "safetensors", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tokenizers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bc/79/845941711811789c85fb7e2599cea425a14a07eda40f50896b9d3fda7492/transformers-5.0.0.tar.gz", hash = "sha256:5f5634efed6cf76ad068cc5834c7adbc32db78bbd6211fb70df2325a9c37dec8", size = 8424830 }
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/52/f3/ac976fa8e305c9e49772527e09fbdc27cc6831b8a2f6b6063406626be5dd/transformers-5.0.0-py3-none-any.whl", hash = "sha256:587086f249ce64c817213cf36afdb318d087f790723e9b3d4500b97832afd52d", size = 10142091 },
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087 },
]

[[package]]