Compare commits


5 Commits

Author SHA1 Message Date
Evan
5266f8d7ae add kimi tool parsing
This patches the kimi tokenizer to add tool calling; it can be reverted
once upstream support is added for kimi-k2
2026-01-22 00:34:23 +00:00
Evan
75036ee9f6 implement mlx-lm tool calling
Splits the runner's generation chunks into tool calls, tokens, and
errors, and emits tool call chunks when the upstream parser detects
them.
2026-01-22 00:34:23 +00:00
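
As a rough illustration of the chunk split described above, here is a hedged sketch using stand-in dataclasses rather than the real exo types (the actual `TokenChunk`/`ToolCallChunk`/`ErrorChunk` definitions appear in the `chunks.py` diff further down):

```python
from dataclasses import dataclass, field


@dataclass
class TokenChunk:
    text: str  # plain generated text


@dataclass
class ToolCallChunk:
    # parsed {name, arguments} items emitted when the upstream parser detects a tool call
    tool_calls: list[dict] = field(default_factory=list)


@dataclass
class ErrorChunk:
    error_message: str  # surfaced to the API with finish_reason="error"


# the runner now yields one of these variants per chunk
# (the real GenerationChunk union also includes ImageChunk for image generation)
GenerationChunk = TokenChunk | ToolCallChunk | ErrorChunk
```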
Alex Cheema
023108a19d Disable image model cards temporarily (#1247)
## Motivation

The image generation feature is not stable and is causing issues for users.

Fixes #1242

## Changes

- Commented out image model cards (flux1-schnell, flux1-dev, qwen-image,
qwen-image-edit-2509) in `src/exo/shared/models/model_cards.py`
- Added a reference to issue #1242 in a comment explaining why they are
disabled

## Why It Works

With the model cards commented out, these image models no longer appear
in the model list, preventing users from attempting to use the unstable
feature until it is stabilized.

## Test Plan

### Manual Testing
- Run exo and verify image models no longer appear in the model list

### Automated Testing
- No changes to automated tests needed - this simply removes models from
the available list

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 22:39:59 +00:00
Jake Hillion
c9818c30b4 dashboard: show model total size on downloads page for pending downloads
The downloads page showed "0B / 0B" for models that hadn't started
downloading yet, because the download progress data only gets populated
after the file list is fetched from HuggingFace.

Added a fetch to the /models API endpoint on page mount and created a
helper function that falls back to storage_size_megabytes when the
download's totalBytes is 0.

This allows users to see the actual model size (e.g., "0 / 25GB")
before a download begins, which is helpful for a future feature that
lets you download models explicitly.
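
The dashboard change itself is Svelte/TypeScript (see the `getModelTotalBytes` helper in the diff below); as a rough sketch of just the fallback logic, in Python, with `model_total_bytes` as a hypothetical stand-in name:

```python
def model_total_bytes(download_total_bytes: int, storage_size_megabytes: int | None) -> int:
    """Prefer the live download total; before a download starts, fall back to the catalogue size."""
    if download_total_bytes > 0:
        return download_total_bytes
    if storage_size_megabytes:
        return storage_size_megabytes * 1024 * 1024
    return 0


# a pending 25GB model shows "0 / 25GB" instead of "0B / 0B"
assert model_total_bytes(0, 25 * 1024) == 25 * 1024**3
```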

Test plan:
- Deployed to a cluster; the previous 0B values now show sensible sizes.
2026-01-21 21:53:54 +00:00
Alex Cheema
8f6726d6be Fix config.json download errors for image models (#1245)
## Motivation

When `get_shard_download_status()` runs, it iterates over all models in
`MODEL_CARDS` and calls `build_full_shard()` → `build_base_shard()` →
`ModelCard.from_hf()`. This unconditionally tried to download
`config.json` from HuggingFace, but image models (FLUX, Qwen-Image)
don't have a root-level config.json file, causing errors:

```
Error downloading shard: File not found: https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/config.json
Error downloading shard: File not found: https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/config.json
Error downloading shard: File not found: https://huggingface.co/Qwen/Qwen-Image/resolve/main/config.json
Error downloading shard: File not found: https://huggingface.co/Qwen/Qwen-Image-Edit-2509/resolve/main/config.json
```

## Changes

### ModelCard.load() fix
- `build_base_shard()` now uses `ModelCard.load()` instead of
`ModelCard.from_hf()`
- `ModelCard.load()` iterates through `MODEL_CARDS.values()` to find a
match by `model_id`

### exo-bench fixes
- Use `name` field instead of `id` for model resolution
- Pass `full_model_id` to `/instance/previews` endpoint
- Make model name matching case-insensitive (see the sketch after this list)
- Update README example model name
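
A minimal sketch of the case-insensitive matching, with a hypothetical `find_model` helper standing in for the logic inside `resolve_model_short_id` (the real change is in the `exo_bench.py` diff below):

```python
def find_model(data: list[dict], model_arg: str) -> dict | None:
    """Return the /models entry whose name matches model_arg, ignoring case."""
    for m in data:
        name = m.get("name")
        if isinstance(name, str) and name.lower() == model_arg.lower():
            return m
    return None


models = [{"name": "Llama-3.2-1B-Instruct-4bit"}]
assert find_model(models, "llama-3.2-1b-instruct-4bit") is not None
```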

## Why It Works

`MODEL_CARDS` uses short names as keys (e.g., `"flux1-schnell"`) but the
`model_id` values are HuggingFace paths (e.g.,
`"black-forest-labs/FLUX.1-schnell"`). When `ModelCard.load()` was
called with the HF path, it didn't match any key and fell back to
`from_hf()` which tried to download config.json.

The fix iterates through `MODEL_CARDS.values()` to find a match by
`model_id`, ensuring predefined models (including image models) use
their registry entries directly without network calls. A key lookup is
unnecessary since `load()` is always called with HF paths which don't
match the short-name keys.
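
For reference, a minimal runnable sketch of the new lookup behaviour, using a simplified stand-in instead of the real `ModelCard`:

```python
from dataclasses import dataclass


@dataclass
class Card:
    model_id: str  # HuggingFace path, e.g. "black-forest-labs/FLUX.1-schnell"


# keyed by short names, values carry the HF path
MODEL_CARDS = {"flux1-schnell": Card("black-forest-labs/FLUX.1-schnell")}


def load(model_id: str) -> Card | None:
    # match on the card's model_id rather than the dict key,
    # so predefined models never hit the config.json download path
    for card in MODEL_CARDS.values():
        if card.model_id == model_id:
            return card
    return None  # the real ModelCard.load() falls back to from_hf(model_id) here


assert load("black-forest-labs/FLUX.1-schnell") is not None
```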

## Test Plan

### Manual Testing
- Run exo and verify no more "Error downloading shard: File not found:
.../config.json" errors for image models
- Run exo-bench and verify model resolution works correctly

### Automated Testing
- `uv run basedpyright` - passes with 0 errors
- `uv run pytest` - all tests pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 21:30:48 +00:00
16 changed files with 548 additions and 362 deletions

View File

@@ -364,7 +364,7 @@ The `exo-bench` tool measures model prefill and token generation speed across di
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--model Llama-3.2-1B-Instruct-4bit \
--pp 128,256,512 \
--tg 128,256
```
@@ -385,7 +385,7 @@ uv run bench/exo_bench.py \
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--model Llama-3.2-1B-Instruct-4bit \
--pp 128,512 \
--tg 128 \
--max-nodes 2 \

View File

@@ -195,14 +195,14 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
data = models.get("data") or []
for m in data:
if m.get("id") == model_arg:
short_id = str(m["id"])
full_id = str(m.get("hugging_face_id") or m["id"])
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["id"])
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
@@ -373,7 +373,7 @@ def main() -> int:
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": short_id}
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []

View File

@@ -172,6 +172,33 @@
}
let downloadOverview = $state<NodeEntry[]>([]);
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
[],
);
async function fetchModels() {
try {
const response = await fetch("/models");
if (response.ok) {
const data = await response.json();
models = data.data || [];
}
} catch (error) {
console.error("Failed to fetch models:", error);
}
}
function getModelTotalBytes(
modelId: string,
downloadTotalBytes: number,
): number {
if (downloadTotalBytes > 0) return downloadTotalBytes;
const model = models.find((m) => m.id === modelId);
if (model?.storage_size_megabytes) {
return model.storage_size_megabytes * 1024 * 1024;
}
return 0;
}
$effect(() => {
try {
@@ -346,6 +373,7 @@
onMount(() => {
// Ensure we fetch at least once when visiting downloads directly
refreshState();
fetchModels();
});
</script>
@@ -454,7 +482,7 @@
{#if model.status !== "completed"}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(
model.totalBytes,
getModelTotalBytes(model.modelId, model.totalBytes),
)}
</div>
{/if}

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer

View File

@@ -4,6 +4,7 @@ import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Literal, cast
from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -56,8 +57,15 @@ from exo.shared.types.api import (
PlacementPreview,
PlacementPreviewResponse,
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -93,7 +101,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -102,7 +110,19 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
if isinstance(chunk, TokenChunk)
else ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
finish_reason=chunk.finish_reason,
)
],
@@ -162,8 +182,12 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
] = {}
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
@@ -439,11 +463,13 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
@@ -462,7 +488,8 @@ class API:
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
del self._chat_completion_queues[command_id]
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
@@ -470,6 +497,7 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
@@ -498,11 +526,12 @@ class API:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
@@ -511,7 +540,18 @@ class API:
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -529,6 +569,7 @@ class API:
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
@@ -539,6 +580,7 @@ class API:
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
@@ -554,7 +596,19 @@ class API:
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
@@ -571,7 +625,7 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
role="assistant", content=combined_text, tool_calls=tool_calls
),
finish_reason=finish_reason,
)
@@ -729,7 +783,9 @@ class API:
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
self._image_generation_queues[command_id], recv = channel[
ImageChunk | ErrorChunk
]()
with recv as chunks:
async for chunk in chunks:
@@ -838,7 +894,9 @@ class API:
stats: ImageGenerationStats | None = None
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
self._image_generation_queues[command_id], recv = channel[
ImageChunk | ErrorChunk
]()
while images_complete < num_images:
with recv as chunks:
@@ -994,7 +1052,6 @@ class API:
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
@@ -1148,27 +1205,26 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(
event.command_id, None
)
elif event.command_id in self._image_generation_queues:
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
queue = self._image_generation_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._chat_completion_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -1,13 +1,9 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from typing import Any
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
@@ -48,95 +44,3 @@ def test_http_exception_handler_formats_openai_style() -> None:
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason
def test_image_chunk_with_error_fields() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message="Image generation failed",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Image generation failed"
assert chunk.data == ""
assert chunk.chunk_index == 0
assert chunk.total_chunks == 1
assert chunk.image_index == 0
def test_image_chunk_without_error() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="base64encodeddata",
chunk_index=0,
total_chunks=1,
image_index=0,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
assert chunk.data == "base64encodeddata"

View File

@@ -59,8 +59,9 @@ class ModelCard(CamelCaseModel):
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
for card in MODEL_CARDS.values():
if card.model_id == model_id:
return card
return await ModelCard.from_hf(model_id)
@staticmethod
@@ -409,158 +410,159 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
# "flux1-schnell": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23782357120),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "flux1-dev": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23802816640),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image-edit-2509": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
}

View File

@@ -54,6 +54,18 @@ class ChatCompletionMessageText(BaseModel):
text: str
class ToolCallItem(BaseModel):
name: str
arguments: str
class ToolCall(BaseModel):
id: str
index: int | None = None
type: Literal["function"] = "function"
function: ToolCallItem
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: (
@@ -61,7 +73,7 @@ class ChatCompletionMessage(BaseModel):
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from enum import Enum
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
@@ -8,24 +7,29 @@ from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
class ChunkType(str, Enum):
Token = "Token"
Image = "Image"
from .worker.runner_response import ToolCallItem
class BaseChunk(TaggedModel):
idx: int
model: ModelId
class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
class ErrorChunk(BaseChunk):
error_message: str
finish_reason: Literal["error"] = "error"
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):
@@ -63,4 +67,4 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk

View File

@@ -1,7 +1,12 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.shared.types.api import (
FinishReason,
GenerationStats,
ImageGenerationStats,
ToolCallItem,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -48,5 +53,9 @@ class PartialImageResponse(BaseRunnerResponse):
yield name, value
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -19,7 +19,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
model_card = await ModelCard.load(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,

View File

@@ -145,10 +145,6 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -256,6 +252,10 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call

View File

@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

View File

@@ -1,8 +1,9 @@
import base64
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,11 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
@@ -42,6 +44,8 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
ToolCallItem,
ToolCallResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -154,6 +158,9 @@ def main(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
@@ -244,17 +251,49 @@ def main(
mlx_generator, tokenizer
)
# TODO: Add tool call parser here
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
for response in mlx_generator:
match response:
case GenerationResponse():
if device_rank == 0:
if (
device_rank == 0
and response.finish_reason == "error"
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
error_message=response.text,
model=shard_metadata.model_card.model_id,
),
)
)
elif device_rank == 0:
assert response.finish_reason not in (
"error",
"tool_calls",
"function_call",
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
@@ -263,6 +302,17 @@ def main(
),
)
)
case ToolCallResponse():
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
),
)
)
# can we make this more explicit?
except Exception as e:
@@ -270,11 +320,8 @@ def main(
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
@@ -328,18 +375,14 @@ def main(
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=0,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message=str(e),
),
@@ -396,13 +439,8 @@ def main(
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=0,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message=str(e),
),
@@ -446,6 +484,18 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
@@ -526,7 +576,6 @@ def _send_image_chunk(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
@@ -568,6 +617,113 @@ def _process_image_response(
)
def parse_tool_calls(
responses: Generator[GenerationResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (json.JSONDecodeError, ValidationError) as e:
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
return ToolCallItem(name=name, arguments=json.dumps(args))
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
@@ -140,6 +140,13 @@ class EventCollector:
pass
class MockTokenizer:
tool_parser = None
tool_call_start = None
tool_call_end = None
has_tool_calling = False
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -171,7 +178,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,

uv.lock (generated)
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -413,7 +413,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
@@ -458,6 +458,16 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@@ -994,8 +1004,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1022,12 +1032,18 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1040,14 +1056,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1078,7 +1086,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1086,6 +1094,16 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"
@@ -2209,15 +2227,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "torch"
version = "2.9.1"