Compare commits


6 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige 2f70419756 uv lock 2026-01-21 21:24:42 +00:00
Ryuichi Leo Takashige 5df602899f Add back custom fork 2026-01-21 21:23:59 +00:00
rltakashige 1ec791fe98 Merge branch 'main' into leo/use-mlx-branch 2026-01-21 21:21:31 +00:00
Ryuichi Leo Takashige cba1afb75b Format 2026-01-21 21:20:57 +00:00
Ryuichi Leo Takashige edaca941b9 Use all gather within pipeline last layer because that makes sense... 2026-01-21 21:20:57 +00:00
rltakashige 678d368651 Import download utils once all modules are loaded (#1238)
## Motivation

Test failed due to circular import

## Test Plan

### Manual Testing
Tried importing and calling the functions, worked fine.

### Automated Testing
Tests pass again
2026-01-21 20:50:34 +00:00
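
For context, the fix named in this commit title (importing the download utils only once all modules are loaded) is the standard remedy for an import cycle: move the import from module scope to call time. A minimal sketch with hypothetical modules `a` and `b`, not the actual exo layout:

```python
# a.py (hypothetical)
import b  # safe: b does not import a at module scope

def helper() -> str:
    return "from a"

# b.py (hypothetical)
def use_helper() -> str:
    # Deferred import: resolved at call time, after module `a` has
    # finished initializing, so the circular reference never bites.
    import a
    return a.helper()
```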
16 changed files with 362 additions and 548 deletions

View File

@@ -364,7 +364,7 @@ The `exo-bench` tool measures model prefill and token generation speed across di
```bash
uv run bench/exo_bench.py \
- --model Llama-3.2-1B-Instruct-4bit \
+ --model llama-3.2-1b \
--pp 128,256,512 \
--tg 128,256
```
@@ -385,7 +385,7 @@ uv run bench/exo_bench.py \
```bash
uv run bench/exo_bench.py \
- --model Llama-3.2-1B-Instruct-4bit \
+ --model llama-3.2-1b \
--pp 128,512 \
--tg 128 \
--max-nodes 2 \

View File

@@ -195,14 +195,14 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
if m.get("id") == model_arg:
short_id = str(m["id"])
full_id = str(m.get("hugging_face_id") or m["id"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
short_id = str(m["id"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
@@ -373,7 +373,7 @@ def main() -> int:
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
"GET", "/instance/previews", params={"model_id": short_id}
)
previews = previews_resp.get("previews") or []

View File

@@ -172,33 +172,6 @@
}
let downloadOverview = $state<NodeEntry[]>([]);
- let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
- [],
- );
- async function fetchModels() {
- try {
- const response = await fetch("/models");
- if (response.ok) {
- const data = await response.json();
- models = data.data || [];
- }
- } catch (error) {
- console.error("Failed to fetch models:", error);
- }
- }
- function getModelTotalBytes(
- modelId: string,
- downloadTotalBytes: number,
- ): number {
- if (downloadTotalBytes > 0) return downloadTotalBytes;
- const model = models.find((m) => m.id === modelId);
- if (model?.storage_size_megabytes) {
- return model.storage_size_megabytes * 1024 * 1024;
- }
- return 0;
- }
$effect(() => {
try {
@@ -373,7 +346,6 @@
onMount(() => {
// Ensure we fetch at least once when visiting downloads directly
refreshState();
- fetchModels();
});
</script>
@@ -482,7 +454,7 @@
{#if model.status !== "completed"}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(
- getModelTotalBytes(model.modelId, model.totalBytes),
+ model.totalBytes,
)}
</div>
{/if}

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer

View File

@@ -4,7 +4,6 @@ import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Literal, cast
- from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -57,15 +56,8 @@ from exo.shared.types.api import (
PlacementPreview,
PlacementPreviewResponse,
StreamingChoiceResponse,
- ToolCall,
)
- from exo.shared.types.chunks import (
- ErrorChunk,
- ImageChunk,
- InputImageChunk,
- TokenChunk,
- ToolCallChunk,
- )
+ from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -101,7 +93,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
- chunk: TokenChunk | ToolCallChunk, command_id: CommandId
+ chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -110,19 +102,7 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
- delta=ChatCompletionMessage(role="assistant", content=chunk.text)
- if isinstance(chunk, TokenChunk)
- else ChatCompletionMessage(
- role="assistant",
- tool_calls=[
- ToolCall(
- id=str(uuid4()),
- index=i,
- function=tool,
- )
- for i, tool in enumerate(chunk.tool_calls)
- ],
- ),
+ delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
],
@@ -182,12 +162,8 @@ class API:
name="dashboard",
)
- self._chat_completion_queues: dict[
- CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
- ] = {}
- self._image_generation_queues: dict[
- CommandId, Sender[ImageChunk | ErrorChunk]
- ] = {}
+ self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
+ self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
@@ -463,13 +439,11 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
- ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
+ ) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
- self._chat_completion_queues[command_id], recv = channel[
- ErrorChunk | ToolCallChunk | TokenChunk
- ]()
+ self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
@@ -488,8 +462,7 @@ class API:
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
- if command_id in self._chat_completion_queues:
- del self._chat_completion_queues[command_id]
+ del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
@@ -497,7 +470,6 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
- assert not isinstance(chunk, ImageChunk)
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
@@ -526,12 +498,11 @@ class API:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
- tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
- if isinstance(chunk, ErrorChunk):
+ if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
@@ -540,18 +511,7 @@ class API:
if model is None:
model = chunk.model
- if isinstance(chunk, TokenChunk):
- text_parts.append(chunk.text)
- if isinstance(chunk, ToolCallChunk):
- tool_calls.extend(
- ToolCall(
- id=str(uuid4()),
- index=i,
- function=tool,
- )
- for i, tool in enumerate(chunk.tool_calls)
- )
+ text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -569,7 +529,6 @@ class API:
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
- tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
@@ -580,7 +539,6 @@ class API:
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
- tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
@@ -596,19 +554,7 @@ class API:
if model is None:
model = chunk.model
- if isinstance(chunk, TokenChunk):
- text_parts.append(chunk.text)
- if isinstance(chunk, ToolCallChunk):
- tool_calls.extend(
- ToolCall(
- id=str(uuid4()),
- index=i,
- function=tool,
- )
- for i, tool in enumerate(chunk.tool_calls)
- )
+ text_parts.append(chunk.text)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
@@ -625,7 +571,7 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text, tool_calls=tool_calls
role="assistant", content=combined_text
),
finish_reason=finish_reason,
)
@@ -783,9 +729,7 @@ class API:
images_complete = 0
try:
- self._image_generation_queues[command_id], recv = channel[
- ImageChunk | ErrorChunk
- ]()
+ self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
@@ -894,9 +838,7 @@ class API:
stats: ImageGenerationStats | None = None
try:
- self._image_generation_queues[command_id], recv = channel[
- ImageChunk | ErrorChunk
- ]()
+ self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
@@ -1052,6 +994,7 @@ class API:
await self._send(
SendInputChunk(
chunk=InputImageChunk(
+ idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
@@ -1205,26 +1148,27 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
- if queue := self._image_generation_queues.get(
- event.command_id, None
- ):
+ if event.command_id in self._chat_completion_queues:
+ assert isinstance(event.chunk, TokenChunk)
+ queue = self._chat_completion_queues.get(event.command_id)
+ if queue is not None:
+ try:
+ await queue.send(event.chunk)
+ except BrokenResourceError:
+ self._chat_completion_queues.pop(
+ event.command_id, None
+ )
+ elif event.command_id in self._image_generation_queues:
+ assert isinstance(event.chunk, ImageChunk)
- try:
- await queue.send(event.chunk)
- except BrokenResourceError:
- self._image_generation_queues.pop(
- event.command_id, None
- )
- if queue := self._chat_completion_queues.get(
- event.command_id, None
- ):
- assert not isinstance(event.chunk, ImageChunk)
- try:
- await queue.send(event.chunk)
- except BrokenResourceError:
- self._chat_completion_queues.pop(event.command_id, None)
+ queue = self._image_generation_queues.get(event.command_id)
+ if queue is not None:
+ try:
+ await queue.send(event.chunk)
+ except BrokenResourceError:
+ self._image_generation_queues.pop(
+ event.command_id, None
+ )
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -1,9 +1,13 @@
# pyright: reportUnusedFunction=false, reportAny=false
- from typing import Any
+ from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
+ from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
+ from exo.shared.types.chunks import ImageChunk, TokenChunk
+ from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
@@ -44,3 +48,95 @@ def test_http_exception_handler_formats_openai_style() -> None:
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason
def test_image_chunk_with_error_fields() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message="Image generation failed",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Image generation failed"
assert chunk.data == ""
assert chunk.chunk_index == 0
assert chunk.total_chunks == 1
assert chunk.image_index == 0
def test_image_chunk_without_error() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="base64encodeddata",
chunk_index=0,
total_chunks=1,
image_index=0,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
assert chunk.data == "base64encodeddata"

View File

@@ -59,9 +59,8 @@ class ModelCard(CamelCaseModel):
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
- for card in MODEL_CARDS.values():
- if card.model_id == model_id:
- return card
+ if model_id in MODEL_CARDS:
+ return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
@@ -410,159 +409,158 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
# "flux1-schnell": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23782357120),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "flux1-dev": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23802816640),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image-edit-2509": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}

View File

@@ -54,18 +54,6 @@ class ChatCompletionMessageText(BaseModel):
text: str
- class ToolCallItem(BaseModel):
- name: str
- arguments: str
- class ToolCall(BaseModel):
- id: str
- index: int | None = None
- type: Literal["function"] = "function"
- function: ToolCallItem
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: (
@@ -73,7 +61,7 @@ class ChatCompletionMessage(BaseModel):
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
- tool_calls: list[ToolCall] | None = None
+ tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None

View File

@@ -1,4 +1,5 @@
from collections.abc import Generator
+ from enum import Enum
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
@@ -7,29 +8,24 @@ from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
- from .worker.runner_response import ToolCallItem
+ class ChunkType(str, Enum):
+ Token = "Token"
+ Image = "Image"
class BaseChunk(TaggedModel):
+ idx: int
model: ModelId
class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
class ErrorChunk(BaseChunk):
error_message: str
finish_reason: Literal["error"] = "error"
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
finish_reason: Literal["tool_calls"] = "tool_calls"
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):
@@ -67,4 +63,4 @@ class InputImageChunk(BaseChunk):
yield name, value
- GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
+ GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,12 +1,7 @@
from collections.abc import Generator
from typing import Any, Literal
- from exo.shared.types.api import (
- FinishReason,
- GenerationStats,
- ImageGenerationStats,
- ToolCallItem,
- )
+ from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -53,9 +48,5 @@ class PartialImageResponse(BaseRunnerResponse):
yield name, value
- class ToolCallResponse(BaseRunnerResponse):
- tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -19,7 +19,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
- model_card = await ModelCard.load(model_id)
+ model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,

View File

@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
+ output = mx.distributed.all_gather(output, group=self.group)[
+ -output.shape[0] :
+ ] # type :ignore
return output
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
- logits = mx.distributed.all_gather(logits, group=group)[
- -logits.shape[0] :
- ] # type :ignore
return logits
cls.__call__ = patched_call
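
The gather-then-slice idiom added above (and removed from `patch_pipeline_model` in favor of the layer-level version) leaves every rank holding the last pipeline stage's output: `all_gather` concatenates each rank's tensor along the first axis, and the `[-x.shape[0]:]` slice keeps only the final rank's rows. A minimal standalone sketch, assuming an initialized MLX distributed group; shapes are illustrative:

```python
import mlx.core as mx

group = mx.distributed.init()  # e.g. 2 ranks when run under `mlx.launch`

x = mx.ones((4, 8))  # this rank's batch-first activations

# all_gather stacks every rank's tensor along axis 0 -> (4 * size, 8);
# keeping the last x.shape[0] rows leaves each rank with an identical
# copy of the highest rank's tensor, i.e. the last stage's output.
y = mx.distributed.all_gather(x, group=group)[-x.shape[0]:]
print(y.shape)  # (4, 8) on every rank
```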

View File

@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

View File

@@ -1,9 +1,8 @@
import base64
import json
import time
from collections.abc import Generator
from functools import cache
- from typing import Any, Callable, Literal
+ from typing import Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -14,12 +13,11 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
- from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
- from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
+ from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
@@ -44,8 +42,6 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
- ToolCallItem,
- ToolCallResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -158,9 +154,6 @@ def main(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
- logger.info(
- f"model has_tool_calling={tokenizer.has_tool_calling}"
- )
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
@@ -251,49 +244,17 @@ def main(
mlx_generator, tokenizer
)
- # Kimi-K2 has tool call sections - we don't care about them
- if "kimi" in shard_metadata.model_card.model_id.lower():
- mlx_generator = filter_kimi_tokens(mlx_generator)
- patch_kimi_tokenizer(tokenizer)
- if tokenizer.has_tool_calling:
- assert tokenizer.tool_call_start
- assert tokenizer.tool_call_end
- assert tokenizer.tool_parser # pyright: ignore[reportAny]
- mlx_generator = parse_tool_calls(
- mlx_generator,
- tokenizer.tool_call_start,
- tokenizer.tool_call_end,
- tokenizer.tool_parser, # pyright: ignore[reportAny]
- )
+ # TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
- if (
- device_rank == 0
- and response.finish_reason == "error"
- ):
- event_sender.send(
- ChunkGenerated(
- command_id=command_id,
- chunk=ErrorChunk(
- error_message=response.text,
- model=shard_metadata.model_card.model_id,
- ),
- )
- )
- elif device_rank == 0:
- assert response.finish_reason not in (
- "error",
- "tool_calls",
- "function_call",
- )
+ if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
+ idx=response.token,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
@@ -302,17 +263,6 @@ def main(
),
)
)
- case ToolCallResponse():
- if device_rank == 0:
- event_sender.send(
- ChunkGenerated(
- command_id=command_id,
- chunk=ToolCallChunk(
- tool_calls=response.tool_calls,
- model=shard_metadata.model_card.model_id,
- ),
- )
- )
# can we make this more explicit?
except Exception as e:
@@ -320,8 +270,11 @@ def main(
event_sender.send(
ChunkGenerated(
command_id=command_id,
- chunk=ErrorChunk(
+ chunk=TokenChunk(
+ idx=0,
model=shard_metadata.model_card.model_id,
+ text="",
+ token_id=0,
+ finish_reason="error",
error_message=str(e),
),
@@ -375,14 +328,18 @@ def main(
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
- chunk=ErrorChunk(
+ chunk=ImageChunk(
+ idx=0,
model=shard_metadata.model_card.model_id,
+ data="",
+ chunk_index=0,
+ total_chunks=1,
+ image_index=0,
+ finish_reason="error",
error_message=str(e),
),
@@ -439,8 +396,13 @@ def main(
event_sender.send(
ChunkGenerated(
command_id=command_id,
- chunk=ErrorChunk(
+ chunk=ImageChunk(
+ idx=0,
model=shard_metadata.model_card.model_id,
+ data="",
+ chunk_index=0,
+ total_chunks=1,
+ image_index=0,
+ finish_reason="error",
error_message=str(e),
),
@@ -484,18 +446,6 @@ def get_gpt_oss_encoding():
return encoding
- def filter_kimi_tokens(
- responses: Generator[GenerationResponse],
- ) -> Generator[GenerationResponse]:
- for resp in responses:
- if (
- resp.text == "<|tool_calls_section_begin|>"
- or resp.text == "<|tool_calls_section_end|>"
- ):
- continue
- yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
@@ -576,6 +526,7 @@ def _send_image_chunk(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
+ idx=chunk_index,
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
@@ -617,113 +568,6 @@ def _process_image_response(
)
def parse_tool_calls(
responses: Generator[GenerationResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (json.JSONDecodeError, ValidationError) as e:
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
return ToolCallItem(name=name, arguments=json.dumps(args))
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
@@ -140,13 +140,6 @@ class EventCollector:
pass
- class MockTokenizer:
- tool_parser = None
- tool_call_start = None
- tool_call_end = None
- has_tool_calling = False
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -178,6 +171,7 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
+ idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,

uv.lock (generated): 59 changes
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -413,7 +413,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
@@ -458,16 +458,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@@ -1004,8 +994,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1032,18 +1022,12 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "mlx-metal", marker = "sys_platform == 'darwin'" },
+ resolution-markers = [
+ "sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1056,6 +1040,14 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
+ [[package]]
+ name = "mlx"
+ version = "0.30.4.dev20260121+fbe306f9"
+ source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
+ resolution-markers = [
+ "sys_platform == 'darwin'",
+ ]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1086,7 +1078,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1094,16 +1086,6 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"
@@ -2227,6 +2209,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "torch"
version = "2.9.1"