mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-21 20:39:59 -05:00
Compare commits
5 Commits
leo/use-ml
...
tool-calli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5266f8d7ae | ||
|
|
75036ee9f6 | ||
|
|
023108a19d | ||
|
|
c9818c30b4 | ||
|
|
8f6726d6be |
@@ -364,7 +364,7 @@ The `exo-bench` tool measures model prefill and token generation speed across di
|
||||
|
||||
```bash
|
||||
uv run bench/exo_bench.py \
|
||||
--model llama-3.2-1b \
|
||||
--model Llama-3.2-1B-Instruct-4bit \
|
||||
--pp 128,256,512 \
|
||||
--tg 128,256
|
||||
```
|
||||
@@ -385,7 +385,7 @@ uv run bench/exo_bench.py \
|
||||
|
||||
```bash
|
||||
uv run bench/exo_bench.py \
|
||||
--model llama-3.2-1b \
|
||||
--model Llama-3.2-1B-Instruct-4bit \
|
||||
--pp 128,512 \
|
||||
--tg 128 \
|
||||
--max-nodes 2 \
|
||||
|
||||
@@ -195,14 +195,14 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m.get("hugging_face_id") or m["id"])
|
||||
if m.get("name").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
@@ -373,7 +373,7 @@ def main() -> int:
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": short_id}
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
|
||||
@@ -172,6 +172,33 @@
|
||||
}
|
||||
|
||||
let downloadOverview = $state<NodeEntry[]>([]);
|
||||
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
|
||||
[],
|
||||
);
|
||||
|
||||
async function fetchModels() {
|
||||
try {
|
||||
const response = await fetch("/models");
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
models = data.data || [];
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch models:", error);
|
||||
}
|
||||
}
|
||||
|
||||
function getModelTotalBytes(
|
||||
modelId: string,
|
||||
downloadTotalBytes: number,
|
||||
): number {
|
||||
if (downloadTotalBytes > 0) return downloadTotalBytes;
|
||||
const model = models.find((m) => m.id === modelId);
|
||||
if (model?.storage_size_megabytes) {
|
||||
return model.storage_size_megabytes * 1024 * 1024;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
try {
|
||||
@@ -346,6 +373,7 @@
|
||||
onMount(() => {
|
||||
// Ensure we fetch at least once when visiting downloads directly
|
||||
refreshState();
|
||||
fetchModels();
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -454,7 +482,7 @@
|
||||
{#if model.status !== "completed"}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(
|
||||
model.totalBytes,
|
||||
getModelTotalBytes(model.modelId, model.totalBytes),
|
||||
)}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
@@ -56,8 +57,15 @@ from exo.shared.types.api import (
|
||||
PlacementPreview,
|
||||
PlacementPreviewResponse,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
InputImageChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
@@ -93,7 +101,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
@@ -102,7 +110,19 @@ def chunk_to_response(
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
|
||||
if isinstance(chunk, TokenChunk)
|
||||
else ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
],
|
||||
),
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -162,8 +182,12 @@ class API:
|
||||
name="dashboard",
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
|
||||
self._chat_completion_queues: dict[
|
||||
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
|
||||
] = {}
|
||||
self._image_generation_queues: dict[
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
] = {}
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -439,11 +463,13 @@ class API:
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
self._chat_completion_queues[command_id], recv = channel[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk
|
||||
]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
@@ -462,7 +488,8 @@ class API:
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
if command_id in self._chat_completion_queues:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
@@ -470,6 +497,7 @@ class API:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
@@ -498,11 +526,12 @@ class API:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
@@ -511,7 +540,18 @@ class API:
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
@@ -529,6 +569,7 @@ class API:
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@@ -539,6 +580,7 @@ class API:
|
||||
self, command_id: CommandId
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
@@ -554,7 +596,19 @@ class API:
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -571,7 +625,7 @@ class API:
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
role="assistant", content=combined_text, tool_calls=tool_calls
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@@ -729,7 +783,9 @@ class API:
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
self._image_generation_queues[command_id], recv = channel[
|
||||
ImageChunk | ErrorChunk
|
||||
]()
|
||||
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
@@ -838,7 +894,9 @@ class API:
|
||||
stats: ImageGenerationStats | None = None
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
self._image_generation_queues[command_id], recv = channel[
|
||||
ImageChunk | ErrorChunk
|
||||
]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
@@ -994,7 +1052,6 @@ class API:
|
||||
await self._send(
|
||||
SendInputChunk(
|
||||
chunk=InputImageChunk(
|
||||
idx=chunk_index,
|
||||
model=resolved_model,
|
||||
command_id=command.command_id,
|
||||
data=chunk_data,
|
||||
@@ -1148,27 +1205,26 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if event.command_id in self._chat_completion_queues:
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
elif event.command_id in self._image_generation_queues:
|
||||
if queue := self._image_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
queue = self._image_generation_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
if queue := self._chat_completion_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert not isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any, get_args
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
@@ -48,95 +44,3 @@ def test_http_exception_handler_formats_openai_style() -> None:
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
|
||||
|
||||
def test_image_chunk_with_error_fields() -> None:
|
||||
chunk = ImageChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message="Image generation failed",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Image generation failed"
|
||||
assert chunk.data == ""
|
||||
assert chunk.chunk_index == 0
|
||||
assert chunk.total_chunks == 1
|
||||
assert chunk.image_index == 0
|
||||
|
||||
|
||||
def test_image_chunk_without_error() -> None:
|
||||
chunk = ImageChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
data="base64encodeddata",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
assert chunk.data == "base64encodeddata"
|
||||
|
||||
@@ -59,8 +59,9 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
@staticmethod
|
||||
async def load(model_id: ModelId) -> "ModelCard":
|
||||
if model_id in MODEL_CARDS:
|
||||
return MODEL_CARDS[model_id]
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == model_id:
|
||||
return card
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
@staticmethod
|
||||
@@ -409,158 +410,159 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
|
||||
# "flux1-schnell": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23782357120),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "flux1-dev": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23802816640),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image-edit-2509": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -54,6 +54,18 @@ class ChatCompletionMessageText(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str
|
||||
index: int | None = None
|
||||
type: Literal["function"] = "function"
|
||||
function: ToolCallItem
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
|
||||
content: (
|
||||
@@ -61,7 +73,7 @@ class ChatCompletionMessage(BaseModel):
|
||||
) = None
|
||||
thinking: str | None = None # Added for GPT-OSS harmony format support
|
||||
name: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_call_id: str | None = None
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
@@ -8,24 +7,29 @@ from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
from .common import CommandId
|
||||
|
||||
|
||||
class ChunkType(str, Enum):
|
||||
Token = "Token"
|
||||
Image = "Image"
|
||||
from .worker.runner_response import ToolCallItem
|
||||
|
||||
|
||||
class BaseChunk(TaggedModel):
|
||||
idx: int
|
||||
model: ModelId
|
||||
|
||||
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ErrorChunk(BaseChunk):
|
||||
error_message: str
|
||||
finish_reason: Literal["error"] = "error"
|
||||
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
@@ -63,4 +67,4 @@ class InputImageChunk(BaseChunk):
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk
|
||||
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import (
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -48,5 +53,9 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
yield name, value
|
||||
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
@@ -19,7 +19,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
model_card = await ModelCard.load(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Literal
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -13,11 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
@@ -42,6 +44,8 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
ToolCallItem,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -154,6 +158,9 @@ def main(
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
@@ -244,17 +251,49 @@ def main(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if device_rank == 0:
|
||||
if (
|
||||
device_rank == 0
|
||||
and response.finish_reason == "error"
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
@@ -263,6 +302,17 @@ def main(
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
@@ -270,11 +320,8 @@ def main(
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
@@ -328,18 +375,14 @@ def main(
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=0,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
@@ -396,13 +439,8 @@ def main(
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=0,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
@@ -446,6 +484,18 @@ def get_gpt_oss_encoding():
|
||||
return encoding
|
||||
|
||||
|
||||
def filter_kimi_tokens(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
for resp in responses:
|
||||
if (
|
||||
resp.text == "<|tool_calls_section_begin|>"
|
||||
or resp.text == "<|tool_calls_section_end|>"
|
||||
):
|
||||
continue
|
||||
yield resp
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
@@ -526,7 +576,6 @@ def _send_image_chunk(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
@@ -568,6 +617,113 @@ def _process_image_response(
|
||||
)
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
responses: Generator[GenerationResponse],
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
in_tool_call = False
|
||||
tool_call_text_parts: list[str] = []
|
||||
for response in responses:
|
||||
# assumption: the tool call start is one token
|
||||
if response.text == tool_call_start:
|
||||
in_tool_call = True
|
||||
continue
|
||||
# assumption: the tool call end is one token
|
||||
if in_tool_call and response.text == tool_call_end:
|
||||
try:
|
||||
# tool_parser returns an arbitrarily nested python dictionary
|
||||
# we actually don't want the python dictionary, we just want to
|
||||
# parse the top level { function: ..., arguments: ... } structure
|
||||
# as we're just gonna hand it back to the api anyway
|
||||
parsed = tool_parser("".join(tool_call_text_parts).strip())
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if isinstance(parsed, list):
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
|
||||
)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if in_tool_call:
|
||||
tool_call_text_parts.append(response.text)
|
||||
continue
|
||||
# fallthrough
|
||||
yield response
|
||||
|
||||
|
||||
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Version of to-be-upstreamed kimi-k2 tool parser
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
|
||||
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
|
||||
tool_call_start = "<|tool_call_begin|>"
|
||||
tool_call_end = "<|tool_call_end|>"
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = func_name[func_name.find(".") + 1 :]
|
||||
|
||||
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
and ((args := obj.get("arguments")) is not None)
|
||||
and isinstance(name, str)
|
||||
):
|
||||
return ToolCallItem(name=name, arguments=json.dumps(args))
|
||||
else:
|
||||
raise ValidationError
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
@@ -140,6 +140,13 @@ class EventCollector:
|
||||
pass
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
tool_parser = None
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -171,7 +178,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
expected_chunk = ChunkGenerated(
|
||||
command_id=COMMAND_1_ID,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="hi",
|
||||
token_id=0,
|
||||
|
||||
Reference in New Issue
Block a user