Compare commits


12 Commits

Author SHA1 Message Date
Andrei Onel
7a36d3968d docs: Update documentation for v1.0.68 release (#1667)
## Motivation

Updated documentation for v1.0.68 release

## Changes

**docs/api.md:**
- Added documentation for new API endpoints: Claude Messages API
(`/v1/messages`), OpenAI Responses API (`/v1/responses`), and Ollama API
compatibility endpoints
- Documented custom model management endpoints (`POST /models/add`,
`DELETE /models/custom/{model_id}`)
- Added `enable_thinking` parameter documentation for thinking-capable
models (DeepSeek V3.1, Qwen3, GLM-4.7)
- Documented usage statistics in responses (prompt_tokens,
completion_tokens, total_tokens)
- Added streaming event format documentation for all API types
- Updated image generation section with FLUX.1-Kontext-dev support and
new dimensions (1024x1365, 1365x1024)
- Added request cancellation documentation
- Updated complete endpoint summary with all new endpoints
- Added security notes about trust_remote_code being opt-in

**README.md:**
- Updated Features section to highlight multiple API compatibility
options
- Added Environment Variables section documenting all configuration
options (EXO_MODELS_PATH, EXO_OFFLINE, EXO_ENABLE_IMAGE_MODELS,
EXO_LIBP2P_NAMESPACE, EXO_FAST_SYNCH, EXO_TRACING_ENABLED)
- Expanded "Using the API" section with examples for Claude Messages
API, OpenAI Responses API, and Ollama API
- Added custom model loading documentation with security notes
- Updated file locations to include log files and custom model cards
paths

**CONTRIBUTING.md:** 
- Added documentation for TOML model cards format and the API adapter
pattern

**docs/architecture.md:**
- Documented the adapter architecture introduced in PR #1167

Closes #1653

---------

Co-authored-by: askmanu[bot] <192355599+askmanu[bot]@users.noreply.github.com>
Co-authored-by: Evan Quiney <evanev7@gmail.com>
2026-03-06 11:32:46 +00:00
Mustafa Alp Yılmaz
eee3432738 feat(mlx): add repetition_penalty and repetition_context_size to chat completions (#1665)
## Problem

Models running through exo can get stuck in repetition loops —
generating the same text over and over until hitting the token limit. It
happens more with quantized models where probability distributions can
become degenerate.

`mlx_lm` has `make_logits_processors(repetition_penalty,
repetition_context_size)` that handles this, but it is never called
anywhere in the pipeline.

## Solution

Wire `repetition_penalty` and `repetition_context_size` through the
stack:

- `ChatCompletionRequest` → accept the params from the client
- `TextGenerationTaskParams` → carry them through the pipeline
- `chat_completions.py` adapter → map to internal params
- `generate.py` → call `make_logits_processors()` and merge into the
`logits_processors` list

When `repetition_penalty` is `None` (default), `make_logits_processors`
returns `[]` so existing behavior is unchanged.
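The merge step can be sketched in plain Python (toy lists stand in for MLX arrays; this `make_logits_processors` is a stand-in mirroring the mlx_lm behavior described above, not exo's actual code):

```python
def make_repetition_processor(penalty, context_size):
    """Return a logits processor that damps tokens seen in the recent context."""
    def processor(tokens, logits):
        recent = set(tokens[-context_size:])
        out = list(logits)
        for t in recent:
            # CTRL-style penalty: divide positive logits, multiply negative ones.
            out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
        return out
    return processor

def make_logits_processors(repetition_penalty=None, repetition_context_size=20):
    # Mirrors the behavior described above: a penalty of None means no processors.
    if repetition_penalty is None:
        return []
    return [make_repetition_processor(repetition_penalty, repetition_context_size)]

# Merge into an existing logits_processors list, as generate.py does:
logits_processors = []
logits_processors.extend(
    make_logits_processors(repetition_penalty=1.15, repetition_context_size=64)
)
```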

## Usage

```json
{
  "model": "mlx-community/...",
  "messages": [...],
  "repetition_penalty": 1.15,
  "repetition_context_size": 64
}
```
2026-03-06 10:51:25 +00:00
Alex Cheema
e8c3a873a6 feat(dashboard): improve model picker feedback and HuggingFace search (#1661)
## Motivation

Fixes #1648. Users adding custom models from HuggingFace got no clear
feedback that the model was successfully added — the model didn't appear
prominently in the list. Additionally, searches were limited to
`mlx-community` models with no fallback.

## Changes

- **Success feedback**: After adding a custom model, a toast
notification appears, the view auto-switches to "All Models", and the
newly added model is scrolled into view with a green highlight that
fades over 4 seconds.
- **HuggingFace search fallback**: The `/models/search` endpoint now
searches `mlx-community` first; if no results are found, it falls back
to searching all of HuggingFace.
- **Inline HF results**: When the main search bar finds no local
matches, HuggingFace search results appear inline with "+ Add" buttons
and a "See all results on Hub" link.
- **Full repo ID for non-mlx models**: Non-mlx-community models now
display the full repo ID (e.g., `meta-llama/Llama-3.1-8B`) instead of
just the short name.

## Why It Works

The toast + scroll + highlight gives immediate, unambiguous feedback
that the model was added and where it lives in the list. The HF search
fallback broadens discoverability while still prioritising mlx-community
models. Inline results in the main search bar mean users don't need to
navigate to the HF tab to discover new models.

## Test Plan

### Manual Testing
- Open model picker → HuggingFace Hub tab → search for a model → click
"+ Add"
- Verify: toast appears, view switches to All Models, model highlighted
with green glow
- Search for a term with no mlx-community results → verify fallback to
full HF search
- Non-mlx-community results show full repo ID (e.g., `org/model-name`)
- On All Models tab, search for a model not in the local list → verify
"From HuggingFace" section appears

### Automated Testing
- Existing tests cover the model catalog and API; no new automated tests
needed for UI behaviour changes

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-03-06 10:23:08 +00:00
Michael Harrigan
afab3095b0 fix: KVPrefixCache Regression (#1668)
2026-03-06 10:11:08 +00:00
ciaranbor
b9d40e8e35 Ciaran/re download bug (#1658)
## Motivation

After deleting a model and re-downloading it, the CachedShardDownloader
returns the stale cached path, so ensure_shard short-circuits and no
download actually happens.

## Changes

- Added invalidate(model_id) method to the ShardDownloader ABC and all
implementations
- CachedShardDownloader.invalidate evicts cache entries matching the
model ID and delegates down
- DownloadCoordinator.delete_model calls invalidate after cancelling
active downloads, before deleting files
- Added end-to-end test that downloads, deletes, and re-downloads a
model through the coordinator

## Why It Works

The cache is cleared when a model is deleted, so the next ensure_shard
call performs a fresh download instead of returning the stale path.
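A minimal sketch of the eviction path (the method names `ensure_shard` and `invalidate` come from the PR; the signatures and cache-key shape are assumptions):

```python
import asyncio

class CachedShardDownloader:
    """Caches resolved shard paths per (model_id, shard); a sketch, not exo's real class."""

    def __init__(self, inner):
        self.inner = inner
        self._cache = {}

    async def ensure_shard(self, model_id, shard):
        key = (model_id, shard)
        if key not in self._cache:
            self._cache[key] = await self.inner.ensure_shard(model_id, shard)
        return self._cache[key]

    def invalidate(self, model_id):
        # Evict every cached entry for this model, then delegate down the wrapper chain.
        for key in [k for k in self._cache if k[0] == model_id]:
            del self._cache[key]
        if hasattr(self.inner, "invalidate"):
            self.inner.invalidate(model_id)
```

In this sketch, `delete_model` would call `invalidate(model_id)` after cancelling active downloads and before deleting files, so the next `ensure_shard` performs a real download.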

## Test Plan

### Automated Testing

New test_re_download_after_delete_completes exercises the full download
→ delete → re-download flow through DownloadCoordinator with
CachedShardDownloader + SingletonShardDownloader wrappers matching
production.
2026-03-05 14:18:17 +00:00
wysie
3a4d635d0c Fix copy code button not working in dashboard (#1659)
Fixes #1657

## Motivation

The copy button on code blocks in the dashboard does nothing when
clicked. This affects all code blocks in assistant responses.

## Root Cause

In `MarkdownContent.svelte`, click event listeners are bound to
`.copy-code-btn` elements via `setupCopyButtons()` inside a Svelte 5
`$effect`. However, the effect fires before the DOM has been updated
with the new HTML, so `querySelectorAll(".copy-code-btn")` finds zero
buttons.

Additionally, during streaming, the `content` prop updates on every
token, causing the entire `{@html processedHtml}` to be re-rendered.
This destroys all previously bound event listeners, even if they were
successfully attached.

## Changes

Replaced the per-button `addEventListener` approach with **event
delegation** — a single click listener on the container element that
catches clicks bubbling up from any `.copy-code-btn` or
`.copy-math-btn`. This:

- Eliminates the timing issue (the listener exists before the buttons
are rendered)
- Survives HTML re-renders during streaming (no need to re-bind)
- Removes the need for `setupCopyButtons()` and the `data-listenerBound`
tracking

## Testing

1. Load any model
2. Prompt it to generate a code block (e.g. "write a hello world in
Python")
3. Click the copy button on the code block
4. Paste — the code is copied correctly
5. Verified the button also works during streaming (before generation
completes)

Co-authored-by: Wysie <wysie@users.noreply.github.com>
2026-03-05 12:41:17 +00:00
Miguel Miranda Dias
8485805042 fix(worker): emit error chunks when a runner dies mid-command (#1645)
Closes #1586

## Summary
- track in-flight tasks in `RunnerSupervisor` (not only unacknowledged
pending tasks)
- when `_check_runner()` detects a crashed runner, emit
`ChunkGenerated(ErrorChunk)` for each in-flight command task
(`TextGeneration`, `ImageGeneration`, `ImageEdits`)
- keep existing `RunnerStatusUpdated(RunnerFailed)` emission so
planner/state still transition correctly
- add a unit test for supervisor crash path to ensure an error chunk is
emitted before failed runner status

## Why
`#1586` reports streams that can hang forever when runners crash during
warmup/loading. This keeps failure signaling at the runner-supervisor
layer, matching maintainer guidance in the issue thread.
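The ordering guarantee can be sketched like this (`RunnerSupervisor`, `ChunkGenerated`, and `ErrorChunk` are named in the PR; everything else here is an illustrative stand-in):

```python
from dataclasses import dataclass

@dataclass
class ErrorChunk:
    message: str

@dataclass
class ChunkGenerated:
    task_id: str
    chunk: ErrorChunk

class RunnerSupervisor:
    def __init__(self, emit):
        self.emit = emit      # event sink
        self.in_flight = {}   # task_id -> command kind, tracked while a task runs

    def _check_runner(self, runner_alive: bool) -> None:
        if runner_alive:
            return
        # Emit an error chunk for each in-flight command task first,
        # so client streams terminate instead of hanging forever...
        for task_id in list(self.in_flight):
            self.emit(ChunkGenerated(task_id, ErrorChunk("runner crashed")))
            self.in_flight.pop(task_id)
        # ...then the failed status, so planner/state still transition.
        self.emit("RunnerStatusUpdated(RunnerFailed)")
```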

## Validation
- attempted: `uv run pytest
src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py`
- blocked locally by environment disk exhaustion while uv tried to
materialize heavy CUDA wheels (`No space left on device` during
`nvidia-cudnn-cu13` extraction)

I kept the change scoped and added a targeted unit test for the failure
path.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-03-04 17:15:58 +00:00
ciaranbor
4de8f801c7 Add reasoning params to chat completion and responses APIs (#1654)
## Motivation

Adds reasoning_effort parameter support to both the Chat Completions and
Responses APIs, aligning with the OpenAI spec and enabling thinking
control for gpt-oss models

## Changes

- Added ReasoningEffort literal type ("none" | "minimal" | "low" |
"medium" | "high" | "xhigh") and a resolve_reasoning_params() helper
that cross-derives reasoning_effort ↔ enable_thinking when only one is
provided
- Added reasoning_effort field to ChatCompletionRequest and reasoning
(with Reasoning model) + enable_thinking to ResponsesRequest
- Both adapters now call resolve_reasoning_params() before building
TextGenerationTaskParams
- reasoning_effort is passed through to the MLX chat template as a
template variable

## Why It Works

resolve_reasoning_params is a pure function that normalises the two
overlapping knobs (reasoning_effort and enable_thinking) into a
consistent pair, so downstream code always has both values regardless of
which the caller supplied.
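One plausible shape for that pure function (only the names and the cross-derivation behavior come from the PR; the specific defaults below, such as deriving "medium" from `enable_thinking=True`, are assumptions for illustration):

```python
from typing import Literal, Optional, Tuple

ReasoningEffort = Literal["none", "minimal", "low", "medium", "high", "xhigh"]

def resolve_reasoning_params(
    reasoning_effort: Optional[ReasoningEffort],
    enable_thinking: Optional[bool],
) -> Tuple[Optional[ReasoningEffort], Optional[bool]]:
    """Normalise the two overlapping knobs into a consistent pair."""
    if reasoning_effort is None and enable_thinking is None:
        return None, None  # caller supplied neither: leave defaults untouched
    if reasoning_effort is None:
        # Derive effort from the boolean toggle (assumed default level).
        reasoning_effort = "medium" if enable_thinking else "none"
    if enable_thinking is None:
        # Derive the toggle from the effort level.
        enable_thinking = reasoning_effort != "none"
    return reasoning_effort, enable_thinking
```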

## Test Plan

### Automated Testing

Added test_resolve_reasoning_params.py with 10 test cases covering:
both-None, both-set passthrough, enable_thinking → effort derivation,
and effort → enable_thinking derivation for every ReasoningEffort
variant.
2026-03-04 14:27:49 +00:00
Owleksiy
5777bf3c39 fix: coerce tool-call argument types from tool schema (#1651)
Apply schema-aware coercion to parsed tool-call arguments so
Hermes-style toolcalls can still return typed JSON (e.g. integer ids).

 - pass request tools into parse_tool_calls
 - coerce parsed argument values by function parameters schema
 - add unit tests for coercion and unknown-tool passthrough

## Motivation

Models that use Hermes-based toolcall syntax (Qwen3.5) can't reliably
call tools with non-string parameters.
Example tool:
```json
    {
      "type": "function",
      "function": {
        "name": "process",
        "description": "Manage background processes",
        "parameters": {
          "type": "object",
          "properties": {
            "action": {
              "type": "string",
              "enum": ["spawn", "output", "kill", "list"]
            },
            "id": {
              "type": "integer",
              "description": "Process id"
            },
            "command": {
              "type": "string",
              "description": "Command to run for spawn"
            }
          },
          "required": ["action"],
          "additionalProperties": false
        }
      }
    }
```
Model transcript:
```
<tool_call>
<function=process>
<parameter=action>
output
</parameter>
<parameter=id>
0
</parameter>
</function>
</tool_call>
```
And the API returns:

`{"id":"a8f11689-d840-4ca5-ab1d-ead3678a11a9","name":"process","arguments":"{\"action\":
\"output\", \"id\": \"0\"}"}}`

Tool definition declared `id` as `integer`, the model output is
type-agnostic, and the translation layer treats everything as a string.

The same Qwen3.5-27B on OpenRouter and GPT-4.1-mini on OpenAI obey the
function signature and emit a correct call:
```
{"name":"process","arguments":"{\"action\": \"output\", \"id\": 0}"}}
```

Steps to reproduce:
```
❯ curl -sS -v http://localhost:52415/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d @- <<'JSON'
{
  "model": "mlx-community/Qwen3.5-27B-4bit",
  "stream": false,
  "temperature": 0,
  "messages": [
    {
      "role": "user",
      "content": "Call the process tool with action=output and id=0. Do not explain anything. Just make the tool call."
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "process",
        "description": "Manage background processes",
        "parameters": {
          "type": "object",
          "properties": {
            "action": {
              "type": "string",
              "enum": ["spawn", "output", "kill", "list"]
            },
            "id": {
              "type": "integer",
              "description": "Process id"
            },
            "command": {
              "type": "string",
              "description": "Command to run for spawn"
            }
          },
          "required": ["action"],
          "additionalProperties": false
        }
      }
    }
  ],
  "tool_choice": {
    "type": "function",
    "function": {
      "name": "process"
    }
  }
}
JSON
```
Check the type of `id` in the returned function call.

## Changes

Function-call arguments are now coerced to the types declared in the
function's parameter schema.

## Why It Works

We now explicitly convert argument values where the schema declares a
type (and skip coercion where it doesn't).
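A sketch of schema-aware coercion along these lines (`coerce_arguments` is a hypothetical name; the PR's actual helper may differ):

```python
def coerce_arguments(arguments: dict, parameters_schema: dict) -> dict:
    """Coerce string argument values to the types declared in the tool's JSON schema."""
    props = parameters_schema.get("properties", {})
    coerced = {}
    for name, value in arguments.items():
        declared = props.get(name, {}).get("type")
        if isinstance(value, str):
            try:
                if declared == "integer":
                    value = int(value)
                elif declared == "number":
                    value = float(value)
                elif declared == "boolean" and value.lower() in ("true", "false"):
                    value = value.lower() == "true"
            except ValueError:
                pass  # unparsable value: leave the string untouched
        coerced[name] = value  # unknown keys pass through unchanged
    return coerced

schema = {
    "type": "object",
    "properties": {"action": {"type": "string"}, "id": {"type": "integer"}},
}
coerce_arguments({"action": "output", "id": "0"}, schema)
```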

## Test Plan

### Manual Testing
Create an instance of any Qwen3.5 model and send the curl command above.
Check that `id` is now serialized as a number.

### Automated Testing
Unit tests to cover basic type conversions

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-03-04 12:22:53 +00:00
Evan Quiney
886192f1e6 ignore closed resource errors when trying to cancel a task (#1652)
fixes an occasional crash during the shutdown of a failed instance.
2026-03-03 16:48:43 +00:00
Evan Quiney
d914acd64e check if we have a task before we delete it (#1634)
Previously this caused a crash; it should instead be logged.
2026-03-03 15:32:12 +00:00
rltakashige
37296c8249 Refactor runner for implementing batching (#1632)
## Motivation

Batching will require us to send tasks concurrently and queue them up.
Our current infrastructure cannot handle that at all. This PR gets us
closer by allowing multiple tasks to be sent in parallel and then
queued.

## Changes

- Change the Plan logic
- Make the runner main loop into a class
- Add a "BatchGenerator" to which tasks can be submitted (tasks are still handled sequentially) and results sent back through an MpSender
- Refactor the runner to accept tasks during generation
- Keep the generator threading
- Separate the runner into several files for better readability
## Test Plan

### Manual Testing
Tested manually; needs a lot more automated testing. Cancellation still
works on a single device. Needs checking on multiple devices.

### Automated Testing

---------

Co-authored-by: Evan Quiney <evanev7@gmail.com>
2026-03-03 14:38:55 +00:00
37 changed files with 3046 additions and 1117 deletions

View File

@@ -48,7 +48,7 @@ def make_logits_processors(
    logit_bias: Optional[Dict[int, float]] = ...,
    repetition_penalty: Optional[float] = ...,
    repetition_context_size: Optional[int] = ...,
-): # -> list[Any]:
+) -> list[Callable[[mx.array, mx.array], mx.array]]:
    """
    Make logits processors for use with ``generate_step``.

View File

@@ -39,6 +39,119 @@ Write pure functions where possible. When adding new code, prefer Rust unless th
Run `nix fmt` to auto-format your code before submitting.
## Model Cards
EXO uses TOML-based model cards to define model metadata and capabilities. Model cards are stored in:
- `resources/inference_model_cards/` for text generation models
- `resources/image_model_cards/` for image generation models
- `~/.exo/custom_model_cards/` for user-added custom models
### Adding a Model Card
To add a new model, create a TOML file with the following structure:
```toml
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 1B"
capabilities = ["text"]
[storage_size]
in_bytes = 729808896
```
### Required Fields
- `model_id`: Hugging Face model identifier
- `n_layers`: Number of transformer layers
- `hidden_size`: Hidden dimension size
- `supports_tensor`: Whether the model supports tensor parallelism
- `tasks`: List of supported tasks (`TextGeneration`, `TextToImage`, `ImageToImage`)
- `family`: Model family (e.g., "llama", "deepseek", "qwen")
- `quantization`: Quantization level (e.g., "4bit", "8bit", "bf16")
- `base_model`: Human-readable base model name
- `capabilities`: List of capabilities (e.g., `["text"]`, `["text", "thinking"]`)
### Optional Fields
- `components`: For multi-component models (like image models with separate text encoders and transformers)
- `uses_cfg`: Whether the model uses classifier-free guidance (for image models)
- `trust_remote_code`: Whether to allow remote code execution (defaults to `false` for security)
### Capabilities
The `capabilities` field defines what the model can do:
- `text`: Standard text generation
- `thinking`: Model supports chain-of-thought reasoning
- `thinking_toggle`: Thinking can be enabled/disabled via `enable_thinking` parameter
- `image_edit`: Model supports image-to-image editing (FLUX.1-Kontext)
### Security Note
By default, `trust_remote_code` is set to `false` for security. Only enable it if the model explicitly requires remote code execution from the Hugging Face hub.
## API Adapters
EXO supports multiple API formats through an adapter pattern. Adapters convert API-specific request formats to the internal `TextGenerationTaskParams` format and convert internal token chunks back to API-specific responses.
### Adapter Architecture
All adapters live in `src/exo/master/adapters/` and follow the same pattern:
1. Convert API-specific requests to `TextGenerationTaskParams`
2. Handle both streaming and non-streaming response generation
3. Convert internal `TokenChunk` objects to API-specific formats
4. Manage error handling and edge cases
### Existing Adapters
- `chat_completions.py`: OpenAI Chat Completions API
- `claude.py`: Anthropic Claude Messages API
- `responses.py`: OpenAI Responses API
- `ollama.py`: Ollama API (for OpenWebUI compatibility)
### Adding a New API Adapter
To add support for a new API format:
1. Create a new adapter file in `src/exo/master/adapters/`
2. Implement a request conversion function:
```python
def your_api_request_to_text_generation(
    request: YourAPIRequest,
) -> TextGenerationTaskParams:
    # Convert API request to internal format
    pass
```
3. Implement streaming response generation:
```python
async def generate_your_api_stream(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None],
) -> AsyncGenerator[str, None]:
    # Convert internal chunks to API-specific streaming format
    pass
```
4. Implement non-streaming response collection:
```python
async def collect_your_api_response(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None],
) -> str:
    # Collect all chunks and return a single response
    pass
```
5. Register the adapter endpoints in `src/exo/master/api.py`
The adapter pattern keeps API-specific logic isolated from core inference systems. Internal systems (worker, runner, event sourcing) only see `TextGenerationTaskParams` and `TokenChunk` objects - no API-specific types cross the adapter boundary.
For detailed API documentation, see [docs/api.md](docs/api.md).
## Testing
EXO relies heavily on manual testing at this point in the project, but this is evolving. Before submitting a change, test both before and after to demonstrate how your change improves behavior. Do the best you can with the hardware you have available - if you need help testing, ask and we'll do our best to assist. Add automated tests where possible - we're actively working to substantially improve our automated testing story.

README.md
View File

@@ -26,6 +26,8 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
- **Multiple API Compatibility**: Compatible with OpenAI Chat Completions API, Claude Messages API, OpenAI Responses API, and Ollama API - use your existing tools and clients.
- **Custom Model Support**: Load custom models from HuggingFace hub to expand the range of available models.
## Dashboard
@@ -196,6 +198,8 @@ exo follows the [XDG Base Directory Specification](https://specifications.freede
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
- **Log files**: `~/.cache/exo/exo_log/` (with automatic log rotation)
- **Custom model cards**: `~/.local/share/exo/custom_model_cards/`
You can override these locations by setting the corresponding XDG environment variables.
@@ -275,8 +279,47 @@ After that, RDMA will be enabled in macOS and exo will take care of the rest.
---
## Environment Variables
exo supports several environment variables for configuration:
| Variable | Description | Default |
|----------|-------------|---------|
| `EXO_MODELS_PATH` | Colon-separated paths to search for pre-downloaded models (e.g., on NFS mounts or shared storage) | None |
| `EXO_MODELS_DIR` | Directory where exo downloads and stores models | `~/.local/share/exo/models` (Linux) or `~/.exo/models` (macOS) |
| `EXO_OFFLINE` | Run without internet connection (uses only local models) | `false` |
| `EXO_ENABLE_IMAGE_MODELS` | Enable image model support | `false` |
| `EXO_LIBP2P_NAMESPACE` | Custom namespace for cluster isolation | None |
| `EXO_FAST_SYNCH` | Control MLX_METAL_FAST_SYNCH behavior (for JACCL backend) | Auto |
| `EXO_TRACING_ENABLED` | Enable distributed tracing for performance analysis | `false` |
**Example usage:**
```bash
# Use pre-downloaded models from NFS mount
EXO_MODELS_PATH=/mnt/nfs/models:/opt/ai-models uv run exo
# Run in offline mode
EXO_OFFLINE=true uv run exo
# Enable image models
EXO_ENABLE_IMAGE_MODELS=true uv run exo
# Use custom namespace for cluster isolation
EXO_LIBP2P_NAMESPACE=my-dev-cluster uv run exo
```
---
### Using the API
exo provides multiple API-compatible interfaces for maximum compatibility with existing tools:
- **OpenAI Chat Completions API** - Compatible with OpenAI clients
- **Claude Messages API** - Compatible with Anthropic's Claude format
- **OpenAI Responses API** - Compatible with OpenAI's Responses format
- **Ollama API** - Compatible with Ollama and tools like OpenWebUI
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
---
@@ -366,14 +409,85 @@ When you're done, delete the instance by its ID (find it via `/state` or `/insta
curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
```
### Claude Messages API Compatibility
Use the Claude Messages API format with the `/v1/messages` endpoint:
```bash
curl -N -X POST http://localhost:52415/v1/messages \
-H 'Content-Type: application/json' \
-d '{
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 1024,
"stream": true
}'
```
### OpenAI Responses API Compatibility
Use the OpenAI Responses API format with the `/v1/responses` endpoint:
```bash
curl -N -X POST http://localhost:52415/v1/responses \
-H 'Content-Type: application/json' \
-d '{
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true
}'
```
### Ollama API Compatibility
exo supports Ollama API endpoints for compatibility with tools like OpenWebUI:
```bash
# Ollama chat
curl -X POST http://localhost:52415/ollama/api/chat \
-H 'Content-Type: application/json' \
-d '{
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": false
}'
# List models (Ollama format)
curl http://localhost:52415/ollama/api/tags
```
### Custom Model Loading from HuggingFace
You can add custom models from the HuggingFace hub:
```bash
curl -X POST http://localhost:52415/models/add \
-H 'Content-Type: application/json' \
-d '{
"model_id": "mlx-community/my-custom-model"
}'
```
**Security Note:**
Custom models requiring `trust_remote_code` in their configuration must be explicitly enabled (default is false) for security. Only enable this if you trust the model's remote code execution. Models are fetched from HuggingFace and stored locally as custom model cards.
**Other useful API endpoints*:**
- List all models: `curl http://localhost:52415/models`
- List downloaded models only: `curl http://localhost:52415/models?status=downloaded`
- Search HuggingFace: `curl "http://localhost:52415/models/search?query=llama&limit=10"`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see:
-- API basic documentation in [docs/api.md](docs/api.md).
+- API documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---
@@ -432,4 +546,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -36,8 +36,12 @@
    return num.toString();
  }
-  // Extract model name from full ID (e.g., "mlx-community/Llama-3.2-1B" -> "Llama-3.2-1B")
-  const modelName = $derived(model.id.split("/").pop() || model.id);
+  // Show short name for mlx-community models, full ID for everything else
+  const modelName = $derived(
+    model.author === "mlx-community"
+      ? model.id.split("/").pop() || model.id
+      : model.id,
+  );
</script>
<div

View File

@@ -507,9 +507,29 @@
  });
  $effect(() => {
-    if (containerRef && processedHtml) {
-      setupCopyButtons();
-    }
+    if (!containerRef || !browser) return;
+    function handleDelegatedClick(event: MouseEvent) {
+      const codeBtn = (event.target as HTMLElement).closest(
+        ".copy-code-btn",
+      ) as HTMLButtonElement | null;
+      if (codeBtn) {
+        handleCopyClick({ currentTarget: codeBtn } as unknown as Event);
+        return;
+      }
+      const mathBtn = (event.target as HTMLElement).closest(
+        ".copy-math-btn",
+      ) as HTMLButtonElement | null;
+      if (mathBtn) {
+        handleMathCopyClick({ currentTarget: mathBtn } as unknown as Event);
+        return;
+      }
+    }
+    containerRef.addEventListener("click", handleDelegatedClick);
+    return () => {
+      containerRef?.removeEventListener("click", handleDelegatedClick);
+    };
  });
</script>

View File

@@ -32,6 +32,7 @@
group: ModelGroup;
isExpanded: boolean;
isFavorite: boolean;
isHighlighted?: boolean;
selectedModelId: string | null;
canModelFit: (id: string) => boolean;
getModelFitStatus: (id: string) => ModelFitStatus;
@@ -48,6 +49,7 @@
group,
isExpanded,
isFavorite,
isHighlighted = false,
selectedModelId,
canModelFit,
getModelFitStatus,
@@ -150,10 +152,11 @@
</script>
<div
+  data-model-ids={group.variants.map((v) => v.id).join(" ")}
  class="border-b border-white/5 last:border-b-0 {!anyVariantFits &&
  !anyVariantHasInstance
    ? 'opacity-50'
-    : ''}"
+    : ''} {isHighlighted ? 'model-just-added' : ''}"
>
<!-- Main row -->
<div
@@ -644,3 +647,21 @@
</div>
{/if}
</div>
<style>
.model-just-added {
animation: highlightFade 4s ease-out forwards;
}
@keyframes highlightFade {
0%,
40% {
background-color: rgba(20, 83, 45, 0.25);
box-shadow: inset 0 0 0 1px rgba(74, 222, 128, 0.4);
}
100% {
background-color: transparent;
box-shadow: none;
}
}
</style>

View File

@@ -1,4 +1,5 @@
<script lang="ts">
import { tick } from "svelte";
import { fade, fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
import FamilySidebar from "./FamilySidebar.svelte";
@@ -7,6 +8,7 @@
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
import { getRecentEntries } from "$lib/stores/recents.svelte";
import { addToast } from "$lib/stores/toast.svelte";
interface ModelInfo {
id: string;
@@ -191,6 +193,13 @@
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
let justAddedModelId = $state<string | null>(null);
let justAddedTimer: ReturnType<typeof setTimeout> | null = null;
// Inline HuggingFace search in main search bar
let mainSearchHfResults = $state<HuggingFaceModel[]>([]);
let mainSearchHfLoading = $state(false);
let mainSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
// Reset transient state when modal opens, but preserve tab selection
$effect(() => {
@@ -200,6 +209,11 @@
showFilters = false;
manualModelId = "";
addModelError = null;
justAddedModelId = null;
if (justAddedTimer) {
clearTimeout(justAddedTimer);
justAddedTimer = null;
}
}
});
@@ -214,6 +228,50 @@
}
});
// Inline HuggingFace search when local search returns no results
$effect(() => {
const query = searchQuery.trim();
const noLocalResults = filteredGroups.length === 0;
if (mainSearchDebounceTimer) {
clearTimeout(mainSearchDebounceTimer);
mainSearchDebounceTimer = null;
}
if (
selectedFamily === "huggingface" ||
selectedFamily === "recents" ||
selectedFamily === "favorites" ||
query.length < 2 ||
!noLocalResults
) {
mainSearchHfResults = [];
mainSearchHfLoading = false;
return;
}
mainSearchHfLoading = true;
mainSearchDebounceTimer = setTimeout(async () => {
try {
const response = await fetch(
`/models/search?query=${encodeURIComponent(query)}&limit=10`,
);
if (response.ok) {
const results: HuggingFaceModel[] = await response.json();
mainSearchHfResults = results.filter(
(r) => !existingModelIds.has(r.id),
);
} else {
mainSearchHfResults = [];
}
} catch {
mainSearchHfResults = [];
} finally {
mainSearchHfLoading = false;
}
}, 500);
});
async function fetchTrendingModels() {
hfIsLoadingTrending = true;
try {
@@ -274,6 +332,24 @@
addModelError = null;
try {
await onAddModel(modelId);
// Success: show toast, switch to All Models, highlight the model
const shortName = modelId.split("/").pop() || modelId;
addToast({ type: "success", message: `Added ${shortName}` });
justAddedModelId = modelId;
selectedFamily = null;
searchQuery = "";
// Scroll to the newly added model after DOM update
await tick();
const el = document.querySelector(
`[data-model-ids~="${CSS.escape(modelId)}"]`,
);
el?.scrollIntoView({ behavior: "smooth", block: "center" });
// Clear highlight after 4 seconds
if (justAddedTimer) clearTimeout(justAddedTimer);
justAddedTimer = setTimeout(() => {
justAddedModelId = null;
justAddedTimer = null;
}, 4000);
} catch (error) {
addModelError =
error instanceof Error ? error.message : "Failed to add model";
@@ -841,6 +917,8 @@
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
isHighlighted={justAddedModelId !== null &&
group.variants.some((v) => v.id === justAddedModelId)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
@@ -910,6 +988,8 @@
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
isHighlighted={justAddedModelId !== null &&
group.variants.some((v) => v.id === justAddedModelId)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
@@ -937,6 +1017,8 @@
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
isHighlighted={justAddedModelId !== null &&
group.variants.some((v) => v.id === justAddedModelId)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
@@ -948,6 +1030,55 @@
{instanceStatuses}
/>
{/each}
<!-- Inline HuggingFace search results (shown when no local results match) -->
{#if filteredGroups.length === 0 && searchQuery.trim().length >= 2 && selectedFamily !== "huggingface" && selectedFamily !== "recents" && selectedFamily !== "favorites"}
{#if mainSearchHfLoading}
<div
class="flex items-center gap-2 px-3 py-2 border-t border-orange-400/20 bg-orange-950/20"
>
<span
class="w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin"
></span>
<span class="text-xs font-mono text-orange-400/60"
>Searching HuggingFace...</span
>
</div>
{:else if mainSearchHfResults.length > 0}
<div
class="sticky top-0 z-10 flex items-center gap-2 px-3 py-2 bg-orange-950/30 border-y border-orange-400/20 backdrop-blur-sm"
>
<span
class="text-xs font-mono text-orange-400 tracking-wider uppercase"
>From HuggingFace</span
>
</div>
{#each mainSearchHfResults as model}
<HuggingFaceResultItem
{model}
isAdded={existingModelIds.has(model.id)}
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
downloadedOnNodes={downloadsData
? getNodesWithModelDownloaded(downloadsData, model.id).map(
getNodeName,
)
: []}
/>
{/each}
<button
type="button"
class="w-full px-3 py-2 text-xs font-mono text-orange-400/60 hover:text-orange-400 hover:bg-orange-500/10 transition-colors text-center"
onclick={() => {
hfSearchQuery = searchQuery;
searchHuggingFace(searchQuery);
selectedFamily = "huggingface";
}}
>
See all results on Hub
</button>
{/if}
{/if}
{/if}
</div>
</div>


@@ -4,7 +4,7 @@ This document describes the REST API exposed by the **EXO** service, as implemen
`src/exo/master/api.py`
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using multiple API-compatible interfaces.
Base URL example:
@@ -144,8 +144,59 @@ Placement result.
Returns the list of available models and their metadata.
**Query parameters:**
* `status`: string (optional) - Filter by `downloaded` to show only downloaded models
**Response:**
Array of model descriptors including `is_custom` field for custom HuggingFace models.
### Add Custom Model
**POST** `/models/add`
Add a custom model from HuggingFace hub.
**Request body (example):**
```json
{
"model_id": "mlx-community/my-custom-model"
}
```
**Response:**
Model descriptor for the added model.
**Security note:**
Models with `trust_remote_code` enabled in their configuration require explicit opt-in (default is false) for security.
### Delete Custom Model
**DELETE** `/models/custom/{model_id}`
Delete a user-added custom model card.
**Path parameters:**
* `model_id`: string, ID of the custom model to delete
**Response:**
Confirmation JSON with deleted model ID.
### Search Models
**GET** `/models/search`
Search HuggingFace Hub for mlx-community models.
**Query parameters:**
* `query`: string (optional) - Search query
* `limit`: integer (default: 20) - Maximum number of results
**Response:**
Array of HuggingFace model search results.
## 4. Inference / Chat Completions
@@ -168,9 +219,123 @@ Executes a chat completion request using an OpenAI-compatible schema. Supports s
}
```
**Request parameters:**
* `model`: string, required - Model ID to use
* `messages`: array, required - Conversation messages
* `stream`: boolean (default: false) - Enable streaming responses
* `max_tokens`: integer (optional) - Maximum tokens to generate
* `temperature`: float (optional) - Sampling temperature
* `top_p`: float (optional) - Nucleus sampling parameter
* `top_k`: integer (optional) - Top-k sampling parameter
* `stop`: string or array (optional) - Stop sequences
* `seed`: integer (optional) - Random seed for reproducibility
* `enable_thinking`: boolean (optional) - Enable thinking mode for capable models (DeepSeek V3.1, Qwen3, GLM-4.7)
* `tools`: array (optional) - Tool definitions for function calling
* `logprobs`: boolean (optional) - Return log probabilities
* `top_logprobs`: integer (optional) - Number of top log probabilities to return
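As a hedged sketch (not EXO client code), the parameters above can be assembled into a request payload like this, omitting unset optional fields rather than sending nulls:

```python
import json

# Hypothetical helper mirroring the documented /v1/chat/completions fields.
# Only a subset of the optional parameters is shown.
def build_chat_request(model, messages, *, stream=False, max_tokens=None,
                       temperature=None, enable_thinking=None):
    payload = {"model": model, "messages": messages, "stream": stream}
    # Optional fields are omitted entirely when not provided
    for key, value in [("max_tokens", max_tokens),
                       ("temperature", temperature),
                       ("enable_thinking", enable_thinking)]:
        if value is not None:
            payload[key] = value
    return payload

req = build_chat_request("llama-3.2-1b",
                         [{"role": "user", "content": "Hello"}],
                         max_tokens=64, enable_thinking=True)
print(json.dumps(req, indent=2))
```

The resulting dict can be POSTed to the endpoint with any HTTP client.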
**Response:**
OpenAI-compatible chat completion response.
**Streaming response format:**
When `stream=true`, returns Server-Sent Events (SSE) with format:
```
data: {"id":"...","object":"chat.completion","created":...,"model":"...","choices":[...]}
data: [DONE]
```
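A client consuming this stream splits on `data: ` prefixes and stops at the `[DONE]` sentinel. A minimal sketch (assuming the usual OpenAI-style chunk objects; a real client would read the response body incrementally):

```python
import json

def parse_sse_lines(lines):
    """Parse 'data: ...' lines from an SSE stream into JSON chunks,
    stopping at the [DONE] sentinel."""
    chunks = []
    for line in lines:
        if not line.startswith("data: "):
            continue  # skip blank keep-alive lines
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunks.append(json.loads(data))
    return chunks

stream = [
    'data: {"choices":[{"delta":{"content":"Hel"}}]}',
    'data: {"choices":[{"delta":{"content":"lo"}}]}',
    "data: [DONE]",
]
text = "".join(c["choices"][0]["delta"]["content"] for c in parse_sse_lines(stream))
print(text)  # Hello
```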
**Non-streaming response includes usage statistics:**
```json
{
"id": "...",
"object": "chat.completion",
"created": 1234567890,
"model": "llama-3.2-1b",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you?"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 15,
"completion_tokens": 8,
"total_tokens": 23
}
}
```
**Cancellation:**
You can cancel an active generation by closing the HTTP connection. The server detects the disconnection and stops processing.
### Claude Messages API
**POST** `/v1/messages`
Executes a chat completion request using the Claude Messages API format. Supports streaming and non-streaming modes.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "user", "content": "Hello" }
],
"max_tokens": 1024,
"stream": false
}
```
**Streaming response format:**
When `stream=true`, returns Server-Sent Events with Claude-specific event types:
* `message_start` - Message generation started
* `content_block_start` - Content block started
* `content_block_delta` - Incremental content chunk
* `content_block_stop` - Content block completed
* `message_delta` - Message metadata updates
* `message_stop` - Message generation completed
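The event sequence above can be consumed by accumulating `content_block_delta` payloads until `message_stop`. A sketch, assuming deltas carry `{"delta": {"text": ...}}` as in Anthropic's published event shapes:

```python
def assemble_claude_stream(events):
    """Collect text from a Claude Messages event stream."""
    parts = []
    stopped = False
    for event in events:
        if event["type"] == "content_block_delta":
            parts.append(event["delta"]["text"])
        elif event["type"] == "message_stop":
            stopped = True
            break
    return "".join(parts), stopped

events = [
    {"type": "message_start"},
    {"type": "content_block_start"},
    {"type": "content_block_delta", "delta": {"text": "Hi"}},
    {"type": "content_block_delta", "delta": {"text": " there"}},
    {"type": "content_block_stop"},
    {"type": "message_stop"},
]
text, done = assemble_claude_stream(events)
print(text, done)  # Hi there True
```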
**Response:**
Claude-compatible messages response.
### OpenAI Responses API
**POST** `/v1/responses`
Executes a chat completion request using the OpenAI Responses API format. Supports streaming and non-streaming modes.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "user", "content": "Hello" }
],
"stream": false
}
```
**Streaming response format:**
When `stream=true`, returns Server-Sent Events with response-specific event types:
* `response.created` - Response generation started
* `response.in_progress` - Response is being generated
* `response.output_item.added` - New output item added
* `response.output_item.done` - Output item completed
* `response.done` - Response generation completed
**Response:**
OpenAI Responses API-compatible response.
### Benchmarked Chat Completions
**POST** `/bench/chat/completions`
@@ -181,7 +346,13 @@ Same as `/v1/chat/completions`, but also returns performance and generation stat
Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics including:
* `prompt_tps` - Tokens per second during prompt processing
* `generation_tps` - Tokens per second during generation
* `prompt_tokens` - Number of prompt tokens
* `generation_tokens` - Number of generated tokens
* `peak_memory_usage` - Peak memory used during generation
### Cancel Command
@@ -204,24 +375,148 @@ Cancels an active generation command (text or image). Notifies workers and close
Returns 404 if the command is not found or already completed.
## 5. Ollama API Compatibility
EXO provides Ollama API compatibility for tools like OpenWebUI.
### Ollama Chat
**POST** `/ollama/api/chat`
**POST** `/ollama/api/api/chat` (alias)
**POST** `/ollama/api/v1/chat` (alias)
Execute a chat request using Ollama API format.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "user", "content": "Hello" }
],
"stream": false
}
```
**Response:**
Ollama-compatible chat response.
### Ollama Generate
**POST** `/ollama/api/generate`
Execute a text generation request using Ollama API format.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"prompt": "Hello",
"stream": false
}
```
**Response:**
Ollama-compatible generation response.
### Ollama Tags
**GET** `/ollama/api/tags`
**GET** `/ollama/api/api/tags` (alias)
**GET** `/ollama/api/v1/tags` (alias)
Returns list of downloaded models in Ollama tags format.
**Response:**
Array of model tags with metadata.
### Ollama Show
**POST** `/ollama/api/show`
Returns model information in Ollama show format.
**Request body:**
```json
{
"name": "llama-3.2-1b"
}
```
**Response:**
Model details including modelfile and family.
### Ollama PS
**GET** `/ollama/api/ps`
Returns list of running models (active instances).
**Response:**
Array of active model instances.
### Ollama Version
**GET** `/ollama/api/version`
**HEAD** `/ollama/` (alias)
**HEAD** `/ollama/api/version` (alias)
Returns version information for Ollama API compatibility.
**Response:**
```json
{
"version": "exo v1.0"
}
```
## 6. Image Generation & Editing
### Image Generation
**POST** `/v1/images/generations`
Executes an image generation request using an OpenAI-compatible schema with additional advanced_params. Supports both streaming and non-streaming modes.
**Request body (example):**
```json
{
"prompt": "a robot playing chess",
"model": "exolabs/FLUX.1-dev",
"n": 1,
"size": "1024x1024",
"stream": false,
"response_format": "b64_json"
}
```
**Request parameters:**
* `prompt`: string, required - Text description of the image
* `model`: string, required - Image model ID
* `n`: integer (default: 1) - Number of images to generate
* `size`: string (default: "auto") - Image dimensions. Supported sizes:
- `512x512`
- `768x768`
- `1024x768`
- `768x1024`
- `1024x1024`
- `1024x1536`
- `1536x1024`
- `1024x1365`
- `1365x1024`
* `stream`: boolean (default: false) - Enable streaming for partial images
* `partial_images`: integer (default: 0) - Number of partial images to stream during generation
* `response_format`: string (default: "b64_json") - Either `url` or `b64_json`
* `quality`: string (default: "medium") - Either `high`, `medium`, or `low`
* `output_format`: string (default: "png") - Either `png`, `jpeg`, or `webp`
* `advanced_params`: object (optional) - Advanced generation parameters
**Advanced Parameters (`advanced_params`):**
| Parameter | Type | Constraints | Description |
@@ -231,8 +526,54 @@ Executes an image generation request using an OpenAI-compatible schema with addi
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
| `negative_prompt` | string | - | Text describing what to avoid in the image |
**Non-streaming response:**
```json
{
"created": 1234567890,
"data": [
{
"b64_json": "iVBORw0KGgoAAAANSUhEUgAA...",
"url": null
}
]
}
```
**Streaming response format:**
When `stream=true` and `partial_images > 0`, returns Server-Sent Events:
```
data: {"type":"partial","image_index":0,"partial_index":1,"total_partials":5,"format":"png","data":{"b64_json":"..."}}
data: {"type":"final","image_index":0,"format":"png","data":{"b64_json":"..."}}
data: [DONE]
```
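A client can validate the `size` parameter against the documented dimensions before sending a request. A minimal sketch; treating `"auto"` as 1024x1024 is an assumption of this example, not documented EXO behaviour:

```python
# Dimensions documented for /v1/images/generations
SUPPORTED_SIZES = {
    "512x512", "768x768", "1024x768", "768x1024", "1024x1024",
    "1024x1536", "1536x1024", "1024x1365", "1365x1024",
}

def validate_size(size: str) -> tuple[int, int]:
    """Return (width, height) for a supported size string."""
    if size == "auto":
        return (1024, 1024)  # assumed fallback for this sketch
    if size not in SUPPORTED_SIZES:
        raise ValueError(f"unsupported size: {size}")
    w, h = size.split("x")
    return (int(w), int(h))

print(validate_size("1365x1024"))  # (1365, 1024)
```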
### Image Editing
**POST** `/v1/images/edits`
Executes an image editing request (img2img) using FLUX.1-Kontext-dev or similar models.
**Request (multipart/form-data):**
* `image`: file, required - Input image to edit
* `prompt`: string, required - Text description of desired changes
* `model`: string, required - Image editing model ID (e.g., `exolabs/FLUX.1-Kontext-dev`)
* `n`: integer (default: 1) - Number of edited images to generate
* `size`: string (optional) - Output image dimensions
* `response_format`: string (default: "b64_json") - Either `url` or `b64_json`
* `input_fidelity`: string (default: "low") - Either `low` or `high` - Controls how closely the output follows the input image
* `stream`: string (default: "false") - Enable streaming
* `partial_images`: string (default: "0") - Number of partial images to stream
* `quality`: string (default: "medium") - Either `high`, `medium`, or `low`
* `output_format`: string (default: "png") - Either `png`, `jpeg`, or `webp`
* `advanced_params`: string (optional) - JSON-encoded advanced parameters
**Response:**
Same format as `/v1/images/generations`.
### Benchmarked Image Generation
@@ -244,16 +585,15 @@ Same as `/v1/images/generations`, but also returns generation statistics.
Same schema as `/v1/images/generations`.
**Response:**
Image generation plus benchmarking metrics including:
* `seconds_per_step` - Average time per denoising step
* `total_generation_time` - Total generation time
* `num_inference_steps` - Number of inference steps used
* `num_images` - Number of images generated
* `image_width` - Output image width
* `image_height` - Output image height
* `peak_memory_usage` - Peak memory used during generation
### Benchmarked Image Editing
@@ -267,37 +607,127 @@ Same schema as `/v1/images/edits`.
**Response:**
Same format as `/bench/images/generations`, including `generation_stats`.
### List Images
**GET** `/images`
List all stored images.
**Response:**
Array of image metadata including URLs and expiration times.
### Get Image
**GET** `/images/{image_id}`
Retrieve a stored image by ID.
**Path parameters:**
* `image_id`: string, ID of the image
**Response:**
Image file with appropriate content type.
## 7. Complete Endpoint Summary
```
# General
GET /node_id
GET /state
GET /events
# Instance Management
POST /instance
GET /instance/{instance_id}
DELETE /instance/{instance_id}
GET /instance/previews
GET /instance/placement
POST /place_instance
# Models
GET /models
GET /v1/models
POST /models/add
DELETE /models/custom/{model_id}
GET /models/search
# Text Generation (OpenAI Chat Completions)
POST /v1/chat/completions
POST /bench/chat/completions
# Text Generation (Claude Messages API)
POST /v1/messages
# Text Generation (OpenAI Responses API)
POST /v1/responses
# Text Generation (Ollama API)
POST /ollama/api/chat
POST /ollama/api/api/chat
POST /ollama/api/v1/chat
POST /ollama/api/generate
GET /ollama/api/tags
GET /ollama/api/api/tags
GET /ollama/api/v1/tags
POST /ollama/api/show
GET /ollama/api/ps
GET /ollama/api/version
HEAD /ollama/
HEAD /ollama/api/version
# Command Control
POST /v1/cancel/{command_id}
# Image Generation
POST /v1/images/generations
POST /bench/images/generations
POST /v1/images/edits
POST /bench/images/edits
GET /images
GET /images/{image_id}
```
## 8. Notes
### API Compatibility
EXO provides multiple API-compatible interfaces:
* **OpenAI Chat Completions API** - Compatible with OpenAI clients and tools
* **Claude Messages API** - Compatible with Anthropic's Claude API format
* **OpenAI Responses API** - Compatible with OpenAI's Responses API format
* **Ollama API** - Compatible with Ollama and tools like OpenWebUI
Existing OpenAI, Claude, or Ollama clients can be pointed to EXO by changing the base URL.
### Custom Models
You can add custom models from HuggingFace using the `/models/add` endpoint. Custom models are identified by the `is_custom` field in model list responses.
**Security:** Models requiring `trust_remote_code` must be explicitly enabled (default is false) for security. Only enable this if you trust the model's remote code.
### Usage Statistics
Chat completion responses include usage statistics with:
* `prompt_tokens` - Number of tokens in the prompt
* `completion_tokens` - Number of tokens generated
* `total_tokens` - Sum of prompt and completion tokens
### Request Cancellation
You can cancel active requests by:
1. Closing the HTTP connection (for streaming requests)
2. Calling `/v1/cancel/{command_id}` (for any request)
The server detects cancellation and stops processing immediately.
### Instance Placement
The instance placement endpoints allow you to plan and preview cluster allocations before creating instances. This helps optimize resource usage across nodes.
### Observability
The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.


@@ -28,6 +28,26 @@ There are currently 5 major systems:
Implements a distributed algorithm for master election in unstable networking conditions
## API Layer
The API system uses multiple adapters to support multiple API formats, converting them to a single request / response type.
### Adapter Pattern
Adapters convert between external API formats and EXO's internal types:
```
Chat Completions → [adapter] → TextGenerationTaskParams → Application
Claude Messages → [adapter] → TextGenerationTaskParams → Application
Responses API → [adapter] → TextGenerationTaskParams → Application
Ollama API → [adapter] → TextGenerationTaskParams → Application
```
Each adapter implements two key functions:
1. **Request conversion**: Converts API-specific requests to `TextGenerationTaskParams`
2. **Response generation**: Converts internal `TokenChunk` streams back to API-specific formats (streaming and non-streaming)
## Topics
There are currently 5 topics:


@@ -1,16 +1,20 @@
import asyncio
import pytest
from exo_pyo3_bindings import (
Keypair,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyFromSwarm,
)
@pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None:
print("PYTHON: starting handle")
h = NetworkingHandle(Keypair.generate())
rt = asyncio.create_task(_await_recv(h))
# sleep for 4 ticks
for i in range(4):
@@ -22,13 +26,11 @@ async def test_sleep_on_multiple_items() -> None:
print("caught it", e)
async def _await_recv(h: NetworkingHandle):
while True:
event = await h.recv()
match event:
case PyFromSwarm.Connection() as c:
print(f"PYTHON: connection update: {c}")
case PyFromSwarm.Message() as m:
print(f"PYTHON: message: {m}")


@@ -19,9 +19,7 @@ def exo_shard_downloader(
max_parallel_downloads: int = 8, offline: bool = False
) -> ShardDownloader:
return SingletonShardDownloader(
ResumableShardDownloader(max_parallel_downloads, offline=offline)
)
@@ -85,39 +83,6 @@ class SingletonShardDownloader(ShardDownloader):
return await self.shard_downloader.get_shard_download_status_for_shard(shard)
class CachedShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
) -> None:
self.shard_downloader.on_progress(callback)
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_card.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async for path, status in self.shard_downloader.get_shard_download_status():
yield path, status
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
return await self.shard_downloader.get_shard_download_status_for_shard(shard)
class ResumableShardDownloader(ShardDownloader):
def __init__(self, max_parallel_downloads: int = 8, offline: bool = False):
self.max_parallel_downloads = max_parallel_downloads


@@ -0,0 +1,211 @@
"""Tests that re-downloading a previously deleted model completes successfully."""
import asyncio
import contextlib
from collections.abc import AsyncIterator, Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable
from unittest.mock import AsyncMock, patch
from exo.download.coordinator import DownloadCoordinator
from exo.download.download_utils import RepoDownloadProgress
from exo.download.impl_shard_downloader import SingletonShardDownloader
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.commands import (
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SystemId
from exo.shared.types.events import Event, NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
NODE_ID = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
MODEL_ID = ModelId("test-org/test-model")
def _make_shard(model_id: ModelId = MODEL_ID) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_id=model_id,
storage_size=Memory.from_mb(100),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=28,
n_layers=28,
)
class FakeShardDownloader(ShardDownloader):
"""Fake downloader that simulates a successful download by firing the
progress callback with status='complete' when ensure_shard is called."""
def __init__(self) -> None:
self._progress_callbacks: list[
Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]
] = []
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
) -> None:
self._progress_callbacks.append(callback)
async def ensure_shard(
self,
shard: ShardMetadata,
config_only: bool = False, # noqa: ARG002
) -> Path:
# Simulate a completed download by firing the progress callback
progress = RepoDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_revision="main",
shard=shard,
completed_files=1,
total_files=1,
downloaded=Memory.from_mb(100),
downloaded_this_session=Memory.from_mb(100),
total=Memory.from_mb(100),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
for cb in self._progress_callbacks:
        await cb(shard, progress)
        return Path("/fake/models") / shard.model_card.model_id.normalize()

    async def get_shard_download_status(
        self,
    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
        if False:  # noqa: SIM108 # empty async generator
            yield (
                Path(),
                RepoDownloadProgress(  # pyright: ignore[reportUnreachable]
                    repo_id="",
                    repo_revision="",
                    shard=_make_shard(),
                    completed_files=0,
                    total_files=0,
                    downloaded=Memory.from_bytes(0),
                    downloaded_this_session=Memory.from_bytes(0),
                    total=Memory.from_bytes(0),
                    overall_speed=0,
                    overall_eta=timedelta(seconds=0),
                    status="not_started",
                ),
            )

    async def get_shard_download_status_for_shard(
        self,
        shard: ShardMetadata,
    ) -> RepoDownloadProgress:
        return RepoDownloadProgress(
            repo_id=str(shard.model_card.model_id),
            repo_revision="main",
            shard=shard,
            completed_files=0,
            total_files=1,
            downloaded=Memory.from_bytes(0),
            downloaded_this_session=Memory.from_bytes(0),
            total=Memory.from_mb(100),
            overall_speed=0,
            overall_eta=timedelta(seconds=0),
            status="not_started",
        )


async def test_re_download_after_delete_completes() -> None:
    """A model that was downloaded, deleted, and then re-downloaded should
    reach DownloadCompleted status. This is an end-to-end test through
    the DownloadCoordinator."""
    cmd_send: Sender[ForwarderDownloadCommand]
    cmd_send, cmd_recv = channel[ForwarderDownloadCommand]()
    event_send, event_recv = channel[Event]()

    fake_downloader = FakeShardDownloader()
    wrapped_downloader = SingletonShardDownloader(fake_downloader)
    coordinator = DownloadCoordinator(
        node_id=NODE_ID,
        shard_downloader=wrapped_downloader,
        download_command_receiver=cmd_recv,
        event_sender=event_send,
    )
    shard = _make_shard()
    origin = SystemId("test")

    with patch("exo.download.coordinator.delete_model", new_callable=AsyncMock):
        # Run the coordinator in the background
        coordinator_task = asyncio.create_task(coordinator.run())
        try:
            # 1. Start first download
            await cmd_send.send(
                ForwarderDownloadCommand(
                    origin=origin,
                    command=StartDownload(target_node_id=NODE_ID, shard_metadata=shard),
                )
            )
            # Wait for DownloadCompleted
            first_completed = await _wait_for_download_completed(event_recv, MODEL_ID)
            assert first_completed is not None, "First download should complete"

            # 2. Delete the model
            await cmd_send.send(
                ForwarderDownloadCommand(
                    origin=origin,
                    command=DeleteDownload(target_node_id=NODE_ID, model_id=MODEL_ID),
                )
            )
            # Give the coordinator time to process the delete
            await asyncio.sleep(0.05)

            # 3. Re-download the same model
            await cmd_send.send(
                ForwarderDownloadCommand(
                    origin=origin,
                    command=StartDownload(target_node_id=NODE_ID, shard_metadata=shard),
                )
            )
            # Wait for second DownloadCompleted — this is the bug: it never arrives
            second_completed = await _wait_for_download_completed(event_recv, MODEL_ID)
            assert second_completed is not None, (
                "Re-download after deletion should complete"
            )
        finally:
            coordinator.shutdown()
            coordinator_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await coordinator_task


async def _wait_for_download_completed(
    event_recv: Receiver[Event], model_id: ModelId, timeout: float = 2.0
) -> DownloadCompleted | None:
    """Drain events until we see a DownloadCompleted for the given model, or timeout."""
    try:
        async with asyncio.timeout(timeout):
            while True:
                event = await event_recv.receive()
                if (
                    isinstance(event, NodeDownloadProgress)
                    and isinstance(event.download_progress, DownloadCompleted)
                    and event.download_progress.shard_metadata.model_card.model_id
                    == model_id
                ):
                    return event.download_progress
    except TimeoutError:
        return None


@@ -26,7 +26,11 @@ from exo.shared.types.chunks import (
     ToolCallChunk,
 )
 from exo.shared.types.common import CommandId
-from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
+from exo.shared.types.text_generation import (
+    InputMessage,
+    TextGenerationTaskParams,
+    resolve_reasoning_params,
+)


 def chat_request_to_text_generation(
@@ -75,6 +79,10 @@ def chat_request_to_text_generation(
         dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
         chat_template_messages.append(dumped)

+    resolved_effort, resolved_thinking = resolve_reasoning_params(
+        request.reasoning_effort, request.enable_thinking
+    )
     return TextGenerationTaskParams(
         model=request.model,
         input=input_messages
@@ -89,12 +97,15 @@ def chat_request_to_text_generation(
         seed=request.seed,
         stream=request.stream,
         tools=request.tools,
-        enable_thinking=request.enable_thinking,
+        reasoning_effort=resolved_effort,
+        enable_thinking=resolved_thinking,
         chat_template_messages=chat_template_messages
         if chat_template_messages
         else None,
         logprobs=request.logprobs or False,
         top_logprobs=request.top_logprobs,
+        repetition_penalty=request.repetition_penalty,
+        repetition_context_size=request.repetition_context_size,
     )


@@ -42,7 +42,11 @@ from exo.shared.types.openai_responses import (
     ResponseTextDoneEvent,
     ResponseUsage,
 )
-from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
+from exo.shared.types.text_generation import (
+    InputMessage,
+    TextGenerationTaskParams,
+    resolve_reasoning_params,
+)


 def _format_sse(event: ResponsesStreamEvent) -> str:
@@ -119,6 +123,11 @@ def responses_request_to_text_generation(
     )
     built_chat_template = chat_template_messages if chat_template_messages else None

+    effort_from_reasoning = request.reasoning.effort if request.reasoning else None
+    resolved_effort, resolved_thinking = resolve_reasoning_params(
+        effort_from_reasoning, request.enable_thinking
+    )
     return TextGenerationTaskParams(
         model=request.model,
         input=input_value,
@@ -132,6 +141,8 @@ def responses_request_to_text_generation(
         stop=request.stop,
         seed=request.seed,
         chat_template_messages=built_chat_template or request.chat_template_messages,
+        reasoning_effort=resolved_effort,
+        enable_thinking=resolved_thinking,
     )


@@ -3,7 +3,7 @@ import contextlib
 import json
 import random
 import time
-from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
+from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
 from datetime import datetime, timezone
 from http import HTTPStatus
 from pathlib import Path
@@ -775,7 +775,7 @@ class API:
         return resolved_model

     def stream_events(self) -> StreamingResponse:
-        def _generate_json_array(events: Iterator[Event]) -> Iterator[str]:
+        def _generate_json_array(events: Iterable[Event]) -> Iterable[str]:
             yield "["
             first = True
             for event in events:
@@ -1586,26 +1586,42 @@ class API:
     async def search_models(
         self, query: str = "", limit: int = 20
     ) -> list[HuggingFaceSearchResult]:
-        """Search HuggingFace Hub for mlx-community models."""
-        from huggingface_hub import list_models
+        """Search HuggingFace Hub — tries mlx-community first, falls back to all of HuggingFace."""
+        from huggingface_hub import ModelInfo, list_models

-        results = list_models(
-            search=query or None,
-            author="mlx-community",
-            sort="downloads",
-            limit=limit,
-        )
-        return [
-            HuggingFaceSearchResult(
-                id=m.id,
-                author=m.author or "",
-                downloads=m.downloads or 0,
-                likes=m.likes or 0,
-                last_modified=str(m.last_modified or ""),
-                tags=list(m.tags or []),
+        def _to_results(models: Iterable[ModelInfo]) -> list[HuggingFaceSearchResult]:
+            return [
+                HuggingFaceSearchResult(
+                    id=m.id,
+                    author=m.author or "",
+                    downloads=m.downloads or 0,
+                    likes=m.likes or 0,
+                    last_modified=str(m.last_modified or ""),
+                    tags=list(m.tags or []),
+                )
+                for m in models
+            ]
+
+        # Search mlx-community first
+        mlx_results = _to_results(
+            list_models(
+                search=query or None,
+                author="mlx-community",
+                sort="downloads",
+                limit=limit,
             )
-            for m in results
-        ]
+        )
+        if mlx_results:
+            return mlx_results
+
+        # Fall back to searching all of HuggingFace
+        return _to_results(
+            list_models(
+                search=query or None,
+                sort="downloads",
+                limit=limit,
+            )
+        )

     async def run(self):
         shutdown_ev = anyio.Event()


@@ -328,17 +328,22 @@ class Master:
                                 task_id=task_id,
                             )
                         )
+                    else:
+                        logger.warning(
+                            f"Nonexistent command {command.cancelled_command_id} cancelled"
+                        )
                 case TaskFinished():
-                    generated_events.append(
-                        TaskDeleted(
-                            task_id=self.command_task_mapping[
-                                command.finished_command_id
-                            ]
-                        )
-                    )
-                    self.command_task_mapping.pop(
-                        command.finished_command_id, None
-                    )
+                    if (
+                        task_id := self.command_task_mapping.pop(
+                            command.finished_command_id, None
+                        )
+                    ) is not None:
+                        generated_events.append(TaskDeleted(task_id=task_id))
+                    else:
+                        logger.warning(
+                            f"Finished command {command.finished_command_id} finished"
+                        )
                 case RequestEventLog():
                     # We should just be able to send everything, since other buffers will ignore old messages
                     # rate limit to 1000 at a time
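The `TaskFinished` branch now folds the lookup and the removal into a single `dict.pop(key, None)` guarded by a walrus assignment, so an unknown command id logs a warning instead of raising `KeyError` (the old code indexed the dict before popping). The idiom in isolation, with illustrative names:

```python
command_task_mapping: dict[str, str] = {"cmd-1": "task-1"}
deleted: list[str] = []
warnings: list[str] = []


def on_finished(command_id: str) -> None:
    # pop() both reads and removes; the walrus keeps it to one dict operation,
    # and the None default turns a missing key into a log line, not a crash.
    if (task_id := command_task_mapping.pop(command_id, None)) is not None:
        deleted.append(task_id)
    else:
        warnings.append(f"Finished command {command_id} not found")


on_finished("cmd-1")  # known id: the mapped task is deleted
on_finished("cmd-1")  # repeated id: already popped, warns instead of KeyError
```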


@@ -258,6 +258,6 @@ def get_node_id_keypair(
     # if no valid credentials, create new ones and persist
     with open(path, "w+b") as f:
-        keypair = Keypair.generate_ed25519()
+        keypair = Keypair.generate()
         f.write(keypair.to_bytes())
     return keypair
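The hunk above only swaps `Keypair.generate_ed25519()` for `Keypair.generate()`; the surrounding load-or-create persistence pattern is unchanged. A sketch of that pattern with `os.urandom` standing in for real key generation (the `Keypair` API itself is not reproduced here):

```python
import os
import tempfile
from pathlib import Path

KEY_LEN = 32  # stand-in size for serialized key material


def get_or_create_key(path: Path) -> bytes:
    """Reuse persisted key material if it looks valid, else generate and persist."""
    if path.exists():
        data = path.read_bytes()
        if len(data) == KEY_LEN:
            return data
    # real code: keypair = Keypair.generate(); f.write(keypair.to_bytes())
    key = os.urandom(KEY_LEN)
    path.write_bytes(key)
    return key


with tempfile.TemporaryDirectory() as d:
    p = Path(d) / "node.key"
    first = get_or_create_key(p)
    second = get_or_create_key(p)  # reloaded from disk, not regenerated
```

Persisting the key gives the node a stable identity across restarts; invalid or truncated files fall through to regeneration rather than failing startup.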


@@ -0,0 +1,30 @@
+import pytest
+
+from exo.shared.types.text_generation import resolve_reasoning_params
+
+
+def test_both_none_returns_none_none() -> None:
+    assert resolve_reasoning_params(None, None) == (None, None)
+
+
+def test_both_set_passes_through_unchanged() -> None:
+    assert resolve_reasoning_params("high", True) == ("high", True)
+    assert resolve_reasoning_params("none", True) == ("none", True)
+    assert resolve_reasoning_params("low", False) == ("low", False)
+
+
+def test_enable_thinking_true_derives_medium() -> None:
+    assert resolve_reasoning_params(None, True) == ("medium", True)
+
+
+def test_enable_thinking_false_derives_none() -> None:
+    assert resolve_reasoning_params(None, False) == ("none", False)
+
+
+def test_reasoning_effort_none_derives_thinking_false() -> None:
+    assert resolve_reasoning_params("none", None) == ("none", False)
+
+
+@pytest.mark.parametrize("effort", ["minimal", "low", "medium", "high", "xhigh"])
+def test_non_none_effort_derives_thinking_true(effort: str) -> None:
+    assert resolve_reasoning_params(effort, None) == (effort, True)  # pyright: ignore[reportArgumentType]


@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
 from exo.shared.models.model_cards import ModelCard, ModelId
 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.memory import Memory
+from exo.shared.types.text_generation import ReasoningEffort
 from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
 from exo.shared.types.worker.shards import Sharding, ShardMetadata
 from exo.utils.pydantic_ext import CamelCaseModel
@@ -198,7 +199,10 @@ class ChatCompletionRequest(BaseModel):
     top_p: float | None = None
     top_k: int | None = None
     tools: list[dict[str, Any]] | None = None
+    reasoning_effort: ReasoningEffort | None = None
     enable_thinking: bool | None = None
+    repetition_penalty: float | None = None
+    repetition_context_size: int | None = None
     tool_choice: str | dict[str, Any] | None = None
     parallel_tool_calls: bool | None = None
     user: str | None = None


@@ -12,6 +12,7 @@ from typing import Any, Literal
 from pydantic import BaseModel, Field

 from exo.shared.types.common import ModelId
+from exo.shared.types.text_generation import ReasoningEffort

 # Type aliases
 ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
@@ -71,6 +72,13 @@ ResponseInputItem = (
 )


+class Reasoning(BaseModel, frozen=True):
+    """Reasoning configuration for OpenAI Responses API."""
+
+    effort: ReasoningEffort | None = None
+    summary: Literal["auto", "concise", "detailed"] | None = None
+
+
 class ResponsesRequest(BaseModel, frozen=True):
     """Request body for OpenAI Responses API.
@@ -89,8 +97,15 @@ class ResponsesRequest(BaseModel, frozen=True):
     stream: bool = False
     tools: list[dict[str, Any]] | None = None
     metadata: dict[str, str] | None = None
+    reasoning: Reasoning | None = None

     # --- exo extensions (not in OpenAI Responses API spec) ---
+    enable_thinking: bool | None = Field(
+        default=None,
+        description="[exo extension] Boolean thinking toggle. Not part of the OpenAI Responses API.",
+        json_schema_extra={"x-exo-extension": True},
+    )
     top_k: int | None = Field(
         default=None,
         description="[exo extension] Top-k sampling parameter. Not part of the OpenAI Responses API.",


@@ -11,6 +11,29 @@ from pydantic import BaseModel

 from exo.shared.types.common import ModelId

 MessageRole = Literal["user", "assistant", "system", "developer"]
+ReasoningEffort = Literal["none", "minimal", "low", "medium", "high", "xhigh"]
+
+
+def resolve_reasoning_params(
+    reasoning_effort: ReasoningEffort | None,
+    enable_thinking: bool | None,
+) -> tuple[ReasoningEffort | None, bool | None]:
+    """
+    enable_thinking=True  -> reasoning_effort="medium"
+    enable_thinking=False -> reasoning_effort="none"
+    reasoning_effort="none" -> enable_thinking=False
+    reasoning_effort=<anything else> -> enable_thinking=True
+    """
+    resolved_effort: ReasoningEffort | None = reasoning_effort
+    resolved_thinking: bool | None = enable_thinking
+    if reasoning_effort is None and enable_thinking is not None:
+        resolved_effort = "medium" if enable_thinking else "none"
+    if enable_thinking is None and reasoning_effort is not None:
+        resolved_thinking = reasoning_effort != "none"
+    return resolved_effort, resolved_thinking


 class InputMessage(BaseModel, frozen=True):
@@ -40,6 +63,9 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
     stop: str | list[str] | None = None
     seed: int | None = None
     chat_template_messages: list[dict[str, Any]] | None = None
+    reasoning_effort: ReasoningEffort | None = None
     enable_thinking: bool | None = None
     logprobs: bool = False
     top_logprobs: int | None = None
+    repetition_penalty: float | None = None
+    repetition_context_size: int | None = None


@@ -10,7 +10,7 @@ from mlx_lm.generate import (
     stream_generate,
 )
 from mlx_lm.models.cache import ArraysCache, RotatingKVCache
-from mlx_lm.sample_utils import make_sampler
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
 from mlx_lm.tokenizer_utils import TokenizerWrapper

 from exo.shared.types.api import (
@@ -437,6 +437,7 @@ def mlx_generate(
     group: mx.distributed.Group | None,
     on_prefill_progress: Callable[[int, int], None] | None = None,
     distributed_prompt_progress_callback: Callable[[], None] | None = None,
+    on_generation_token: Callable[[], None] | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -469,11 +470,16 @@ def mlx_generate(
             f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
         )

-    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
+    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = (
+        make_logits_processors(
+            repetition_penalty=task.repetition_penalty,
+            repetition_context_size=task.repetition_context_size,
+        )
+    )
     if is_bench:
         # Only sample length eos tokens
         eos_ids = eos_ids_from_tokenizer(tokenizer)
-        logits_processors = [ban_token_ids(eos_ids)]
+        logits_processors = [ban_token_ids(eos_ids)] + logits_processors

     sampler = make_sampler(
         temp=task.temperature if task.temperature is not None else 0.7,
@@ -644,6 +650,9 @@ def mlx_generate(
                 full_prompt_tokens, caches, cache_snapshots
             )

+        if on_generation_token is not None:
+            on_generation_token()
+
         yield GenerationResponse(
             text=text,
             token=out.token,
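`make_logits_processors` wires `mlx_lm`'s repetition penalty into the sampler. The underlying transform (as popularized by CTRL, Keskar et al. 2019) rescales the logits of tokens seen in a sliding window of recent output: positive logits are divided by the penalty and negative logits multiplied by it, so a penalty > 1 always lowers the token's probability. A dependency-free sketch of that rule, not the mlx implementation:

```python
def apply_repetition_penalty(
    logits: list[float], history: list[int], penalty: float, context_size: int
) -> list[float]:
    """Penalize token ids seen in the last `context_size` generated tokens."""
    out = list(logits)
    for token in set(history[-context_size:]):
        # Dividing a positive logit or multiplying a negative one both push
        # the value toward (or further below) zero when penalty > 1.
        out[token] = out[token] / penalty if out[token] > 0 else out[token] * penalty
    return out


logits = [2.0, -1.0, 0.5, 3.0]
penalized = apply_repetition_penalty(
    logits, history=[0, 1], penalty=2.0, context_size=64
)
```

Note the benchmark path prepends `ban_token_ids(eos_ids)` rather than replacing the list, so the penalty still applies during benchmarking.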


@@ -554,6 +554,8 @@ def apply_chat_template(
         # Jinja ignores unknown variables, so passing both is safe.
         extra_kwargs["enable_thinking"] = task_params.enable_thinking
         extra_kwargs["thinking"] = task_params.enable_thinking
+        if task_params.reasoning_effort is not None:
+            extra_kwargs["reasoning_effort"] = task_params.reasoning_effort

     patched_template: str | None = None
     if task_params.tools:


@@ -297,10 +297,10 @@ def _pending_tasks(
         # the task status _should_ be set to completed by the LAST runner
         # it is currently set by the first
         # this is definitely a hack
-        if task.task_id in runner.completed:
+        if task.task_id in runner.completed or task.task_id in runner.in_progress:
             continue
-        if isinstance(runner.status, RunnerReady) and all(
+        if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
             isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
             for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
         ):
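The widened guard means a task is skipped not only when a runner has completed it but also while it is in flight, closing a window where an in-progress task could be scheduled twice. Reduced to sets, the skip condition is (illustrative names):

```python
def should_skip(task_id: str, completed: set[str], in_progress: set[str]) -> bool:
    # Before the fix only `completed` was checked, so a task that had started
    # but not yet finished could be handed out a second time.
    return task_id in completed or task_id in in_progress


completed = {"t1"}
in_progress = {"t2"}
skip_done = should_skip("t1", completed, in_progress)
skip_running = should_skip("t2", completed, in_progress)
skip_new = should_skip("t3", completed, in_progress)
```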


@@ -32,11 +32,19 @@ def entrypoint(
     # Import main after setting global logger - this lets us just import logger from this module
     try:
         if bound_instance.is_image_model:
-            from exo.worker.runner.image_models.runner import main
-        else:
-            from exo.worker.runner.llm_inference.runner import main
-
-        main(bound_instance, event_sender, task_receiver, cancel_receiver)
+            from exo.worker.runner.image_models.runner import Runner as ImageRunner
+
+            runner = ImageRunner(
+                bound_instance, event_sender, task_receiver, cancel_receiver
+            )
+            runner.main()
+        else:
+            from exo.worker.runner.llm_inference.runner import Runner
+
+            runner = Runner(
+                bound_instance, event_sender, task_receiver, cancel_receiver
+            )
+            runner.main()
     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")


@@ -182,272 +182,266 @@ def _send_image_chunk(
     )


-def main(
-    bound_instance: BoundInstance,
-    event_sender: MpSender[Event],
-    task_receiver: MpReceiver[Task],
-    cancel_receiver: MpReceiver[TaskId],
-):
-    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-    resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
-
-    instance, runner_id, shard_metadata = (
-        bound_instance.instance,
-        bound_instance.bound_runner_id,
-        bound_instance.bound_shard,
-    )
-    device_rank = shard_metadata.device_rank
-
-    logger.info("hello from the runner")
-    if getattr(shard_metadata, "immediate_exception", False):
-        raise Exception("Fake exception - runner failed to spin up.")
-    if timeout := getattr(shard_metadata, "should_timeout", 0):
-        time.sleep(timeout)
-
-    setup_start_time = time.time()
-    cancelled_tasks = set[TaskId]()
-
-    image_model: DistributedImageModel | None = None
-    group = None
-
-    current_status: RunnerStatus = RunnerIdle()
-    logger.info("runner created")
-    event_sender.send(
-        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
-    )
-    seen = set[TaskId]()
-    with task_receiver as tasks:
-        for task in tasks:
-            if task.task_id in seen:
-                logger.warning("repeat task - potential error")
-            seen.add(task.task_id)
-            cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
-            event_sender.send(
-                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
-            )
-            match task:
-                case ConnectToGroup() if isinstance(
-                    current_status, (RunnerIdle, RunnerFailed)
-                ):
-                    logger.info("runner connecting")
-                    current_status = RunnerConnecting()
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
-                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
-                    group = initialize_mlx(bound_instance)
-                    logger.info("runner connected")
-                    current_status = RunnerConnected()
-
-                # we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
-                case LoadModel() if (
-                    isinstance(current_status, RunnerConnected) and group is not None
-                ) or (isinstance(current_status, RunnerIdle) and group is None):
-                    current_status = RunnerLoading()
-                    logger.info("runner loading")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
-                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
-
-                    assert (
-                        ModelTask.TextToImage in shard_metadata.model_card.tasks
-                        or ModelTask.ImageToImage in shard_metadata.model_card.tasks
-                    ), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
-
-                    image_model = initialize_image_model(bound_instance)
-                    current_status = RunnerLoaded()
-                    logger.info("runner loaded")
-
-                case StartWarmup() if isinstance(current_status, RunnerLoaded):
-                    current_status = RunnerWarmingUp()
-                    logger.info("runner warming up")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
-                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
-                    logger.info(f"warming up inference for instance: {instance}")
-                    assert image_model
-                    image = warmup_image_generator(model=image_model)
-                    if image is not None:
-                        logger.info(f"warmed up by generating {image.size} image")
-                    else:
-                        logger.info("warmup completed (non-primary node)")
-                    logger.info(
-                        f"runner initialized in {time.time() - setup_start_time} seconds"
-                    )
-                    current_status = RunnerReady()
-                    logger.info("runner ready")
-
-                case ImageGeneration(
-                    task_params=task_params, command_id=command_id
-                ) if isinstance(current_status, RunnerReady):
-                    assert image_model
-                    logger.info(f"received image generation request: {str(task)[:500]}")
-                    current_status = RunnerRunning()
-                    logger.info("runner running")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
-                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
-                    try:
-                        image_index = 0
-                        for response in generate_image(
-                            model=image_model, task=task_params
-                        ):
-                            is_primary_output = _is_primary_output_node(shard_metadata)
-                            if is_primary_output:
-                                match response:
-                                    case PartialImageResponse():
-                                        logger.info(
-                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
-                                        )
-                                        _process_image_response(
-                                            response,
-                                            command_id,
-                                            shard_metadata,
-                                            event_sender,
-                                            image_index,
-                                        )
-                                    case ImageGenerationResponse():
-                                        logger.info("sending final ImageChunk")
-                                        _process_image_response(
-                                            response,
-                                            command_id,
-                                            shard_metadata,
-                                            event_sender,
-                                            image_index,
-                                        )
-                                        image_index += 1
-                    # can we make this more explicit?
-                    except Exception as e:
-                        if _is_primary_output_node(shard_metadata):
-                            event_sender.send(
-                                ChunkGenerated(
-                                    command_id=command_id,
-                                    chunk=ErrorChunk(
-                                        model=shard_metadata.model_card.model_id,
-                                        finish_reason="error",
-                                        error_message=str(e),
-                                    ),
-                                )
-                            )
-                        raise
-                    finally:
-                        _send_traces_if_enabled(event_sender, task.task_id, device_rank)
-                    current_status = RunnerReady()
-                    logger.info("runner ready")
-
-                case ImageEdits(task_params=task_params, command_id=command_id) if (
-                    isinstance(current_status, RunnerReady)
-                ):
-                    assert image_model
-                    logger.info(f"received image edits request: {str(task)[:500]}")
-                    current_status = RunnerRunning()
-                    logger.info("runner running")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
-                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
-                    try:
-                        image_index = 0
-                        for response in generate_image(
-                            model=image_model, task=task_params
-                        ):
-                            if _is_primary_output_node(shard_metadata):
-                                match response:
-                                    case PartialImageResponse():
-                                        logger.info(
-                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
-                                        )
-                                        _process_image_response(
-                                            response,
-                                            command_id,
-                                            shard_metadata,
-                                            event_sender,
-                                            image_index,
-                                        )
-                                    case ImageGenerationResponse():
-                                        logger.info("sending final ImageChunk")
-                                        _process_image_response(
-                                            response,
-                                            command_id,
-                                            shard_metadata,
-                                            event_sender,
-                                            image_index,
-                                        )
-                                        image_index += 1
-                    except Exception as e:
-                        if _is_primary_output_node(shard_metadata):
-                            event_sender.send(
-                                ChunkGenerated(
-                                    command_id=command_id,
-                                    chunk=ErrorChunk(
-                                        model=shard_metadata.model_card.model_id,
-                                        finish_reason="error",
-                                        error_message=str(e),
-                                    ),
-                                )
-                            )
-                        raise
-                    finally:
-                        _send_traces_if_enabled(event_sender, task.task_id, device_rank)
-                    current_status = RunnerReady()
-                    logger.info("runner ready")
-
-                case Shutdown():
-                    current_status = RunnerShuttingDown()
-                    logger.info("runner shutting down")
-                    if not TYPE_CHECKING:
-                        del image_model, group
-                    mx.clear_cache()
-                    import gc
-
-                    gc.collect()
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
-                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
-                    current_status = RunnerShutdown()
-
-                case _:
-                    raise ValueError(
-                        f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
-                    )
-
-            was_cancelled = (task.task_id in cancelled_tasks) or (
-                TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
-            )
-            if not was_cancelled:
-                event_sender.send(
-                    TaskStatusUpdated(
-                        task_id=task.task_id, task_status=TaskStatus.Complete
-                    )
-                )
-            event_sender.send(
-                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
-            )
-            if isinstance(current_status, RunnerShutdown):
-                break
+class Runner:
+    def __init__(
+        self,
+        bound_instance: BoundInstance,
+        event_sender: MpSender[Event],
+        task_receiver: MpReceiver[Task],
+        cancel_receiver: MpReceiver[TaskId],
+    ):
+        self.event_sender = event_sender
+        self.task_receiver = task_receiver
+        self.cancel_receiver = cancel_receiver
+        self.bound_instance = bound_instance
+
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
+
+        self.instance, self.runner_id, self.shard_metadata = (
+            bound_instance.instance,
+            bound_instance.bound_runner_id,
+            bound_instance.bound_shard,
+        )
+        self.device_rank = self.shard_metadata.device_rank
+
+        logger.info("hello from the runner")
+        if getattr(self.shard_metadata, "immediate_exception", False):
+            raise Exception("Fake exception - runner failed to spin up.")
+        if timeout := getattr(self.shard_metadata, "should_timeout", 0):
+            time.sleep(timeout)
+
+        self.setup_start_time = time.time()
+        self.cancelled_tasks = set[TaskId]()
+
+        self.image_model: DistributedImageModel | None = None
+        self.group = None
+
+        self.current_status: RunnerStatus = RunnerIdle()
+        logger.info("runner created")
+        self.update_status(RunnerIdle())
+        self.seen = set[TaskId]()
+
+    def update_status(self, status: RunnerStatus):
+        self.current_status = status
+        self.event_sender.send(
+            RunnerStatusUpdated(
+                runner_id=self.runner_id, runner_status=self.current_status
+            )
+        )
+
+    def send_task_status(self, task: Task, status: TaskStatus):
+        self.event_sender.send(
+            TaskStatusUpdated(task_id=task.task_id, task_status=status)
+        )
+
+    def acknowledge_task(self, task: Task):
+        self.event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+    def main(self):
+        with self.task_receiver as tasks:
+            for task in tasks:
+                if task.task_id in self.seen:
+                    logger.warning("repeat task - potential error")
+                self.seen.add(task.task_id)
+                self.cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
+                self.send_task_status(task, TaskStatus.Running)
+                self.handle_task(task)
+                was_cancelled = (task.task_id in self.cancelled_tasks) or (
+                    TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
+                )
+                if not was_cancelled:
+                    self.send_task_status(task, TaskStatus.Complete)
+                self.update_status(self.current_status)
+                if isinstance(self.current_status, RunnerShutdown):
+                    break
+
+    def handle_task(self, task: Task):
+        match task:
+            case ConnectToGroup() if isinstance(
+                self.current_status, (RunnerIdle, RunnerFailed)
+            ):
+                logger.info("runner connecting")
+                self.update_status(RunnerConnecting())
+                self.acknowledge_task(task)
+                self.group = initialize_mlx(self.bound_instance)
+                logger.info("runner connected")
+                self.current_status = RunnerConnected()
+
+            # we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
+            case LoadModel() if (
+                isinstance(self.current_status, RunnerConnected)
+                and self.group is not None
+            ) or (isinstance(self.current_status, RunnerIdle) and self.group is None):
+                logger.info("runner loading")
+                self.update_status(RunnerLoading())
+                self.acknowledge_task(task)
+
+                assert (
+                    ModelTask.TextToImage in self.shard_metadata.model_card.tasks
+                    or ModelTask.ImageToImage in self.shard_metadata.model_card.tasks
+                ), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}"
+
+                self.image_model = initialize_image_model(self.bound_instance)
+                self.current_status = RunnerLoaded()
+                logger.info("runner loaded")
+
+            case StartWarmup() if isinstance(self.current_status, RunnerLoaded):
+                logger.info("runner warming up")
+                self.update_status(RunnerWarmingUp())
+                self.acknowledge_task(task)
+                logger.info(f"warming up inference for instance: {self.instance}")
+                assert self.image_model
+                image = warmup_image_generator(model=self.image_model)
+                if image is not None:
+                    logger.info(f"warmed up by generating {image.size} image")
+                else:
+                    logger.info("warmup completed (non-primary node)")
+                logger.info(
+                    f"runner initialized in {time.time() - self.setup_start_time} seconds"
+                )
+                self.current_status = RunnerReady()
+                logger.info("runner ready")
+
+            case ImageGeneration(task_params=task_params, command_id=command_id) if (
+                isinstance(self.current_status, RunnerReady)
+            ):
+                assert self.image_model
+                logger.info(f"received image generation request: {str(task)[:500]}")
+                logger.info("runner running")
+                self.update_status(RunnerRunning())
+                self.acknowledge_task(task)
+                try:
+                    image_index = 0
+                    for response in generate_image(
+                        model=self.image_model, task=task_params
+                    ):
+                        is_primary_output = _is_primary_output_node(self.shard_metadata)
+                        if is_primary_output:
+                            match response:
+                                case PartialImageResponse():
+                                    logger.info(
+                                        f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
+                                    )
+                                    _process_image_response(
+                                        response,
+                                        command_id,
+                                        self.shard_metadata,
+                                        self.event_sender,
+                                        image_index,
+                                    )
+                                case ImageGenerationResponse():
+                                    logger.info("sending final ImageChunk")
+                                    _process_image_response(
+                                        response,
+                                        command_id,
+                                        self.shard_metadata,
+                                        self.event_sender,
+                                        image_index,
+                                    )
+                                    image_index += 1
+                # can we make this more explicit?
+                except Exception as e:
+                    if _is_primary_output_node(self.shard_metadata):
+                        self.event_sender.send(
+                            ChunkGenerated(
+                                command_id=command_id,
+                                chunk=ErrorChunk(
+                                    model=self.shard_metadata.model_card.model_id,
+                                    finish_reason="error",
+                                    error_message=str(e),
+                                ),
+                            )
+                        )
+                    raise
+                finally:
+                    _send_traces_if_enabled(
+                        self.event_sender, task.task_id, self.device_rank
+                    )
+                self.current_status = RunnerReady()
+                logger.info("runner ready")
+
+            case ImageEdits(task_params=task_params, command_id=command_id) if (
+                isinstance(self.current_status, RunnerReady)
+            ):
+                assert self.image_model
+                logger.info(f"received image edits request: {str(task)[:500]}")
+                logger.info("runner running")
+                self.update_status(RunnerRunning())
+                self.acknowledge_task(task)
+                try:
+                    image_index = 0
+                    for response in generate_image(
+                        model=self.image_model, task=task_params
+                    ):
+                        if _is_primary_output_node(self.shard_metadata):
+                            match response:
+                                case PartialImageResponse():
+                                    logger.info(
+                                        f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
+                                    )
+                                    _process_image_response(
+                                        response,
+                                        command_id,
+                                        self.shard_metadata,
+                                        self.event_sender,
+                                        image_index,
+                                    )
+                                case ImageGenerationResponse():
+                                    logger.info("sending final ImageChunk")
+                                    _process_image_response(
+                                        response,
+                                        command_id,
+                                        self.shard_metadata,
+                                        self.event_sender,
+                                        image_index,
+                                    )
+                                    image_index += 1
+                except Exception as e:
+                    if _is_primary_output_node(self.shard_metadata):
+                        self.event_sender.send(
+                            ChunkGenerated(
+                                command_id=command_id,
+                                chunk=ErrorChunk(
+                                    model=self.shard_metadata.model_card.model_id,
+                                    finish_reason="error",
+                                    error_message=str(e),
+                                ),
+                            )
+                        )
+                    raise
+                finally:
+                    _send_traces_if_enabled(
+                        self.event_sender, task.task_id, self.device_rank
+                    )
+                self.current_status = RunnerReady()
+                logger.info("runner ready")
+
+            case Shutdown():
+                logger.info("runner shutting down")
+                if not TYPE_CHECKING:
+                    del self.image_model, self.group
+                mx.clear_cache()
+                import gc
+
+                gc.collect()
+                self.update_status(RunnerShuttingDown())
+                self.acknowledge_task(task)
+                self.current_status = RunnerShutdown()
+
+            case _:
+                raise ValueError(
+                    f"Received {task.__class__.__name__} outside of state machine in {self.current_status=}"
+                )


@@ -0,0 +1,346 @@
import itertools
import math
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Generator, Iterable
from dataclasses import dataclass, field
from typing import cast
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
from exo.shared.types.common import ModelId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import TaskId, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
mlx_generate,
warmup_inference,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
mx_any,
)
from exo.worker.runner.bootstrap import logger
from .model_output_parsers import apply_all_parsers
from .tool_parsers import ToolParser
class Cancelled:
pass
class Finished:
pass
class GeneratorQueue[T]:
def __init__(self):
self._q = deque[T]()
def push(self, t: T):
self._q.append(t)
def gen(self) -> Generator[T | None]:
while True:
if len(self._q) == 0:
yield None
else:
yield self._q.popleft()
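The queue above is the non-blocking bridge between the mlx generator (producer) and the parsing generators (consumers): `gen()` yields `None` whenever the deque is empty, and the oldest buffered item otherwise. A minimal self-contained sketch of the same pattern (untyped so it runs on any recent Python):

```python
from collections import deque

class GeneratorQueue:
    """Non-blocking bridge between a push-based producer and a pull-based consumer."""

    def __init__(self):
        self._q = deque()

    def push(self, t):
        self._q.append(t)

    def gen(self):
        # Infinite generator: None signals "nothing buffered yet".
        while True:
            yield self._q.popleft() if self._q else None

q = GeneratorQueue()
g = q.gen()
first = next(g)   # queue is empty, so this is None
q.push("token-a")
q.push("token-b")
second = next(g)  # "token-a"
third = next(g)   # "token-b"
```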
class InferenceGenerator(ABC):
@abstractmethod
def warmup(self) -> None: ...
@abstractmethod
def submit(
self,
task: TextGeneration,
) -> None: ...
@abstractmethod
def step(
self,
) -> Iterable[
tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]
]: ...
@abstractmethod
def close(self) -> None: ...
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
"""Check for debug prompt triggers in the input."""
from exo.worker.engines.mlx.utils_mlx import mlx_force_oom
if len(task_params.input) == 0:
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
raise Exception("Artificial runner exception - for testing purposes only.")
if EXO_RUNNER_MUST_OOM in prompt:
mlx_force_oom()
if EXO_RUNNER_MUST_TIMEOUT in prompt:
time.sleep(100)
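The debug triggers above are plain substring sentinels scanned in the first input message. A hedged sketch of the same check, with the MLX-specific OOM and timeout branches dropped so it stays self-contained:

```python
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

def check_for_debug_prompts(prompt):
    # Raise only when the sentinel appears anywhere in the prompt text.
    if prompt and EXO_RUNNER_MUST_FAIL in prompt:
        raise RuntimeError("Artificial runner exception - for testing purposes only.")

check_for_debug_prompts(None)                  # no-op on empty input
check_for_debug_prompts("normal user prompt")  # no-op on ordinary text
try:
    check_for_debug_prompts(f"please {EXO_RUNNER_MUST_FAIL} now")
    triggered = False
except RuntimeError:
    triggered = True
```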
@dataclass(eq=False)
class SequentialGenerator(InferenceGenerator):
model: Model
tokenizer: TokenizerWrapper
group: mx.distributed.Group | None
kv_prefix_cache: KVPrefixCache | None
tool_parser: ToolParser | None
model_id: ModelId
device_rank: int
cancel_receiver: MpReceiver[TaskId]
event_sender: MpSender[Event]
check_for_cancel_every: int = 50
_cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)
_maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)
_queue: deque[TextGeneration] = field(default_factory=deque, init=False)
_active: (
tuple[
TextGeneration,
# mlx generator that does the actual decoding work
Generator[GenerationResponse],
# queue the mlx generator pushes into and the parser generator pulls from
GeneratorQueue[GenerationResponse],
# generator yielding parsed outputs
Generator[GenerationResponse | ToolCallResponse | None],
]
| None
) = field(default=None, init=False)
def warmup(self):
logger.info(f"warming up inference for instance: {self.model_id}")
t = time.monotonic()
toks = warmup_inference(
model=self.model,
tokenizer=self.tokenizer,
group=self.group,
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / max(time.monotonic() - t, 0.001)), 100
)
self.check_for_cancel_every = check_for_cancel_every
if self.group is not None:
self.check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]),
group=self.group,
)
).item()
)
logger.info(
f"runner checking for cancellation every {self.check_for_cancel_every} tokens"
)
def submit(
self,
task: TextGeneration,
) -> None:
self._cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
self._maybe_queue.append(task)
def agree_on_tasks(self) -> None:
"""Agree between all ranks about the task ordering (some may have received in different order or not at all)."""
agreed, different = mx_all_gather_tasks(self._maybe_queue, self.group)
self._queue.extend(task for task in self._maybe_queue if task in agreed)
self._maybe_queue = [task for task in self._maybe_queue if task in different]
def step(
self,
) -> Iterable[
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
]:
if self._active is None:
self.agree_on_tasks()
if self._queue:
self._start_next()
else:
return map(lambda task: (task, Cancelled()), self._cancelled_tasks)
assert self._active is not None
task, mlx_gen, queue, output_generator = self._active
response = None
try:
queue.push(next(mlx_gen))
response = next(output_generator)
except (StopIteration, PrefillCancelled):
response = Finished()
self._active = None
if self._queue:
self._start_next()
except Exception as e:
self._send_error(task, e)
self._active = None
raise
return itertools.chain(
[] if response is None else [(task.task_id, response)],
map(lambda task: (task, Cancelled()), self._cancelled_tasks),
)
def _start_next(self) -> None:
task = self._queue.popleft()
try:
mlx_gen = self._build_generator(task)
except Exception as e:
self._send_error(task, e)
raise
queue = GeneratorQueue[GenerationResponse]()
output_generator = apply_all_parsers(
queue.gen(),
apply_chat_template(self.tokenizer, task.task_params),
self.tool_parser,
self.tokenizer,
type(self.model),
self.model_id,
task.task_params.tools,
)
self._active = (task, mlx_gen, queue, output_generator)
def _send_error(self, task: TextGeneration, e: Exception) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
_check_for_debug_prompts(task.task_params)
prompt = apply_chat_template(self.tokenizer, task.task_params)
def on_prefill_progress(processed: int, total: int) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=PrefillProgressChunk(
model=self.model_id,
processed_tokens=processed,
total_tokens=total,
),
)
)
def distributed_prompt_progress_callback() -> None:
self._cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self._cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
self.agree_on_tasks()
tokens_since_cancel_check = self.check_for_cancel_every
def on_generation_token() -> None:
nonlocal tokens_since_cancel_check
tokens_since_cancel_check += 1
if tokens_since_cancel_check >= self.check_for_cancel_every:
tokens_since_cancel_check = 0
self._cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self._cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
self.agree_on_tasks()
return mlx_generate(
model=self.model,
tokenizer=self.tokenizer,
task=task.task_params,
prompt=prompt,
kv_prefix_cache=self.kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
on_generation_token=on_generation_token,
group=self.group,
)
def close(self) -> None:
del self.model, self.tokenizer, self.group
def mx_all_gather_tasks(
tasks: list[TextGeneration],
group: mx.distributed.Group | None,
) -> tuple[list[TextGeneration], list[TextGeneration]]:
def encode_task_id(task_id: TaskId) -> list[int]:
utf8_task_id = task_id.encode()
return [
int.from_bytes(utf8_task_id[i : i + 1]) for i in range(len(utf8_task_id))
]
def decode_task_id(encoded_task_id: list[int]) -> TaskId:
return TaskId(
bytes.decode(b"".join((x).to_bytes(length=1) for x in encoded_task_id))
)
uuid_byte_length = 36
n_tasks = len(tasks)
all_counts = cast(
list[int],
mx.distributed.all_gather(mx.array([n_tasks]), group=group).tolist(),
)
max_tasks = max(all_counts)
world_size: int = 1 if group is None else group.size()
if max_tasks == 0:
return [], []
padded = [encode_task_id(task.task_id) for task in tasks] + [
[0] * uuid_byte_length
] * (max_tasks - n_tasks)
gathered = cast(
list[list[list[int]]],
mx.distributed.all_gather(mx.array(padded), group=group)
.reshape(world_size, max_tasks, -1)
.tolist(),
)
all_task_ids: list[list[TaskId]] = [
[decode_task_id(encoded_task_id) for encoded_task_id in rank_tasks[:count]]
for rank_tasks, count in zip(gathered, all_counts, strict=True)
]
agreed_ids: set[TaskId] = set(all_task_ids[0])
for rank_tasks in all_task_ids[1:]:
agreed_ids &= set(rank_tasks)
local_tasks = {task.task_id: task for task in tasks}
agreed = [local_tasks[tid] for tid in sorted(agreed_ids)]
different = [task for task in tasks if task.task_id not in agreed_ids]
return agreed, different
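mx_all_gather_tasks ships task IDs between ranks as fixed-width arrays of byte values; the inner helpers just map a UUID string to a list of ints and back. A standalone sketch of that round trip (pure Python, no mlx):

```python
UUID_BYTE_LENGTH = 36  # length of a canonical UUID string, as assumed above

def encode_task_id(task_id: str) -> list[int]:
    # One int per UTF-8 byte, so the list can be packed into a numeric array.
    return list(task_id.encode())

def decode_task_id(encoded: list[int]) -> str:
    return bytes(encoded).decode()

tid = "123e4567-e89b-12d3-a456-426614174000"
encoded = encode_task_id(tid)
roundtrip = decode_task_id(encoded)
```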


@@ -0,0 +1,378 @@
from collections.abc import Generator
from functools import cache
from typing import Any
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.engines.mlx.utils_mlx import (
detect_thinking_prompt_suffix,
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def apply_all_parsers(
receiver: Generator[GenerationResponse | None],
prompt: str,
tool_parser: ToolParser | None,
tokenizer: TokenizerWrapper,
model_type: type[Model],
model_id: ModelId,
tools: list[dict[str, Any]] | None,
) -> Generator[GenerationResponse | ToolCallResponse | None]:
mlx_generator = receiver
if tokenizer.has_thinking:
mlx_generator = parse_thinking_models(
mlx_generator,
tokenizer,
starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),
)
if issubclass(model_type, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
elif (
issubclass(model_type, DeepseekV32Model)
and "deepseek" in model_id.normalize().lower()
):
mlx_generator = parse_deepseek_v32(mlx_generator)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser, tools)
return mlx_generator
def parse_gpt_oss(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
# Debug: log every token with state
logger.debug(
f"parse_gpt_oss token={response.token} text={response.text!r} "
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
f"state={stream.state} current_tool={current_tool_name!r}"
)
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
logger.info(
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
)
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
],
usage=response.usage,
)
tool_arg_parts = []
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
if ch != "analysis" and thinking:
thinking = False
if delta:
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
if response.finish_reason is not None:
yield response
def parse_deepseek_v32(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
Uses accumulated-text matching (not per-token marker checks) because
DSML markers like <DSMLfunction_calls> may span multiple tokens.
Also handles <think>...</think> blocks for thinking mode.
"""
from exo.worker.engines.mlx.dsml_encoding import (
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
parse_dsml_output,
)
accumulated = ""
in_tool_call = False
thinking = False
# Tokens buffered while we detect the start of a DSML block
pending_buffer: list[GenerationResponse] = []
# Text accumulated during a tool call block
tool_call_text = ""
for response in responses:
if response is None:
yield None
continue
# ── Handle thinking tags ──
if not thinking and THINKING_START in response.text:
thinking = True
# Yield any text before the <think> tag
before = response.text[: response.text.index(THINKING_START)]
if before:
yield response.model_copy(update={"text": before})
continue
if thinking and THINKING_END in response.text:
thinking = False
# Yield any text after the </think> tag
after = response.text[
response.text.index(THINKING_END) + len(THINKING_END) :
]
if after:
yield response.model_copy(update={"text": after, "is_thinking": False})
continue
if thinking:
yield response.model_copy(update={"is_thinking": True})
continue
# ── Handle tool call accumulation ──
if in_tool_call:
tool_call_text += response.text
if TOOL_CALLS_END in tool_call_text:
# Parse the accumulated DSML block
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# EOS reached before end marker — yield buffered text as-is
if response.finish_reason is not None:
logger.info("DSML tool call parsing interrupted by EOS")
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# ── Detect start of tool call block ──
accumulated += response.text
if TOOL_CALLS_START in accumulated:
# The start marker might be split across pending_buffer + current token
start_idx = accumulated.index(TOOL_CALLS_START)
# Yield any pending tokens that are purely before the marker
pre_text = accumulated[:start_idx]
if pre_text:
# Flush pending buffer tokens that contributed text before the marker
for buf_resp in pending_buffer:
if pre_text:
chunk = buf_resp.text
if len(chunk) <= len(pre_text):
yield buf_resp
pre_text = pre_text[len(chunk) :]
else:
yield buf_resp.model_copy(update={"text": pre_text})
pre_text = ""
pending_buffer = []
tool_call_text = accumulated[start_idx:]
accumulated = ""
# Check if the end marker is already present (entire tool call in one token)
if TOOL_CALLS_END in tool_call_text:
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
tool_call_text = ""
else:
in_tool_call = True
continue
# Check if accumulated text might be the start of a DSML marker
# Buffer tokens if we see a partial match at the end
if _could_be_dsml_prefix(accumulated):
pending_buffer.append(response)
continue
# No partial match — flush all pending tokens and the current one
for buf_resp in pending_buffer:
yield buf_resp
pending_buffer = []
accumulated = ""
yield response
# Flush any remaining pending buffer at generator end
for buf_resp in pending_buffer:
yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
"""Check if the end of text could be the start of a DSML function_calls marker.
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
This allows us to buffer tokens until we can determine if a tool call is starting.
"""
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
# Only check the last portion of text that could overlap with the marker
max_check = len(TOOL_CALLS_START)
tail = text[-max_check:] if len(text) > max_check else text
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
for i in range(len(tail)):
suffix = tail[i:]
if TOOL_CALLS_START.startswith(suffix):
return True
return False
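Because markers can be split across tokens, the buffer check asks whether any suffix of the accumulated tail is a prefix of the start marker. A sketch with a stand-in marker string (the real value lives in exo.worker.engines.mlx.dsml_encoding):

```python
TOOL_CALLS_START = "<DSMLfunction_calls>"  # assumed stand-in for the real marker

def could_be_dsml_prefix(text: str, marker: str = TOOL_CALLS_START) -> bool:
    # Only the last len(marker) characters can overlap with the marker start.
    tail = text[-len(marker):] if len(text) > len(marker) else text
    # True if any suffix of the tail is a prefix of the marker.
    return any(marker.startswith(tail[i:]) for i in range(len(tail)))

partial = could_be_dsml_prefix("some text then <DSMLfunc")  # marker may be starting: buffer
plain = could_be_dsml_prefix("just ordinary prose")         # no overlap: flush the buffer
```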
def parse_thinking_models(
responses: Generator[GenerationResponse | None],
tokenizer: TokenizerWrapper,
starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
"""Route thinking tokens via is_thinking flag.
Swallows think tag tokens, sets is_thinking on all others.
Always yields tokens with finish_reason to avoid hanging the chunk stream.
"""
in_thinking = starts_in_thinking
for response in responses:
if response is None:
yield None
continue
is_think_tag = (
tokenizer.think_end is not None and response.text == tokenizer.think_end
) or (
tokenizer.think_start is not None and response.text == tokenizer.think_start
)
if is_think_tag:
in_thinking = response.text != tokenizer.think_end
# Never swallow finish_reason — the chunk stream needs it to terminate.
if response.finish_reason is not None:
yield response.model_copy(update={"text": "", "is_thinking": False})
continue
yield response.model_copy(update={"is_thinking": in_thinking})
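Stripped of the tokenizer plumbing, the routing rule above is: swallow the tag tokens themselves, and stamp every other token with the current thinking state. A simplified sketch over plain strings, assuming the think tags arrive as whole tokens:

```python
def route_thinking(tokens, think_start="<think>", think_end="</think>",
                   starts_in_thinking=True):
    in_thinking = starts_in_thinking
    out = []
    for text in tokens:
        if text in (think_start, think_end):
            # Tag tokens are swallowed; they only flip the state.
            in_thinking = text != think_end
            continue
        out.append((text, in_thinking))
    return out

routed = route_thinking(["let me reason", "</think>", "final answer"])
```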
def parse_tool_calls(
responses: Generator[GenerationResponse | None],
tool_parser: ToolParser,
tools: list[dict[str, Any]] | None,
) -> Generator[GenerationResponse | ToolCallResponse | None]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse("".join(tool_call_text_parts).strip(), tools)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
else:
# fallthrough
yield response
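parse_tool_calls reduces to: once a token opens the start marker, buffer everything until a token closes it, then hand the buffered text to the parser. A marker-buffering sketch with a hypothetical JSON body parser standing in for ToolParser:

```python
import json

def collect_tool_calls(chunks, start="<tool_call>", end="</tool_call>"):
    """Yield ('text', s) for plain output and ('tool', parsed) for completed calls."""
    in_call, parts = False, []
    for text in chunks:
        if not in_call and text.startswith(start):
            in_call = True
        if in_call:
            parts.append(text)
            if text.endswith(end):
                # Strip the markers and parse the buffered body.
                body = "".join(parts)[len(start):-len(end)]
                yield ("tool", json.loads(body))
                in_call, parts = False, []
            continue
        yield ("text", text)

events = list(collect_tool_calls(
    ["hello ", "<tool_call>", '{"name": "add", "arguments": {"a": 1}}', "</tool_call>"]
))
```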


File diff suppressed because it is too large


@@ -1,4 +1,5 @@
import json
import math
from dataclasses import dataclass
from typing import Any, Callable
@@ -9,7 +10,177 @@ from exo.shared.types.api import ToolCallItem
class ToolParser:
start_parsing: str
end_parsing: str
-parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
+_inner_parser: Callable[[str], list[ToolCallItem] | None]
def parse(
self, text: str, tools: list[dict[str, Any]] | None
) -> list[ToolCallItem] | None:
parsed = self._inner_parser(text)
if parsed is None:
return None
if tools is not None:
parsed = _coerce_tool_calls_to_schema(parsed, tools)
return parsed
def _json_type_matches(value: Any, expected_type: str) -> bool: # pyright: ignore[reportAny]
if expected_type == "object":
return isinstance(value, dict)
if expected_type == "array":
return isinstance(value, list)
if expected_type == "string":
return isinstance(value, str)
if expected_type == "integer":
return isinstance(value, int) and not isinstance(value, bool)
if expected_type == "number":
return (isinstance(value, int) and not isinstance(value, bool)) or isinstance(
value, float
)
if expected_type == "boolean":
return isinstance(value, bool)
if expected_type == "null":
return value is None
return False
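The one subtlety in the matcher above is that Python's bool is a subclass of int, so it has to be excluded explicitly from "integer" and "number". A condensed sketch of the same checks:

```python
def json_type_matches(value, expected_type: str) -> bool:
    checks = {
        "object": lambda v: isinstance(v, dict),
        "array": lambda v: isinstance(v, list),
        "string": lambda v: isinstance(v, str),
        # bool subclasses int, so exclude it from the numeric types.
        "integer": lambda v: isinstance(v, int) and not isinstance(v, bool),
        "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
        "boolean": lambda v: isinstance(v, bool),
        "null": lambda v: v is None,
    }
    check = checks.get(expected_type)
    return check(value) if check else False

int_ok = json_type_matches(3, "integer")           # plain int matches
bool_not_int = json_type_matches(True, "integer")  # bool is excluded
num_ok = json_type_matches(3, "number")            # ints count as numbers
```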
def _coerce_tool_arg_with_schema(value: Any, schema: dict[str, Any]) -> Any: # pyright: ignore[reportAny]
schema_type = schema.get("type")
if isinstance(schema_type, list):
for candidate in schema_type: # pyright: ignore[reportUnknownVariableType]
if not isinstance(candidate, str):
continue
if candidate == "null" and value is None:
return None
candidate_schema = {**schema, "type": candidate}
coerced = _coerce_tool_arg_with_schema(value, candidate_schema) # pyright: ignore[reportAny]
if _json_type_matches(coerced, candidate):
return coerced # pyright: ignore[reportAny]
return value # pyright: ignore[reportAny]
if not isinstance(schema_type, str):
return value # pyright: ignore[reportAny]
if schema_type == "object":
parsed = value # pyright: ignore[reportAny]
if isinstance(parsed, str):
try:
parsed = json.loads(parsed) # pyright: ignore[reportAny]
except Exception:
return value # pyright: ignore[reportAny]
if not isinstance(parsed, dict):
return value # pyright: ignore[reportAny]
properties = schema.get("properties")
if not isinstance(properties, dict):
return parsed # pyright: ignore[reportUnknownVariableType]
return {
key: (
_coerce_tool_arg_with_schema(prop_value, prop_schema) # pyright: ignore[reportUnknownArgumentType]
if isinstance(prop_schema, dict)
else prop_value
)
for key, prop_value in parsed.items() # pyright: ignore[reportUnknownVariableType]
for prop_schema in [properties.get(key)] # type: ignore
}
if schema_type == "array":
parsed = value # pyright: ignore[reportAny]
if isinstance(parsed, str):
try:
parsed = json.loads(parsed) # pyright: ignore[reportAny]
except Exception:
return value # pyright: ignore[reportAny]
if not isinstance(parsed, list):
return value # pyright: ignore[reportAny]
item_schema = schema.get("items")
if not isinstance(item_schema, dict):
return parsed # pyright: ignore[reportUnknownVariableType]
return [_coerce_tool_arg_with_schema(item, item_schema) for item in parsed] # type: ignore
if schema_type == "integer":
if isinstance(value, bool):
return value
if isinstance(value, int):
return value
if isinstance(value, float) and value.is_integer():
return int(value)
if isinstance(value, str):
try:
return int(value.strip())
except ValueError:
return value
return value
if schema_type == "number":
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return value
if isinstance(value, str):
try:
num = float(value.strip())
if math.isfinite(num):
return num
except ValueError:
return value
return value
if schema_type == "boolean":
if isinstance(value, bool):
return value
if isinstance(value, str):
lowered = value.strip().lower()
if lowered == "true":
return True
if lowered == "false":
return False
return value
return value # pyright: ignore[reportAny]
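The scalar branches above can be exercised in isolation: strings from the model are coerced to the schema's type only when that is unambiguous, and anything else is returned untouched. A condensed sketch of the integer/boolean cases:

```python
def coerce_scalar(value, schema_type: str):
    # bools pass through untouched (they are ints in Python, so guard first)
    if isinstance(value, bool):
        return value
    if schema_type == "integer" and isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return value
    if schema_type == "boolean" and isinstance(value, str):
        lowered = value.strip().lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
    return value

as_int = coerce_scalar(" 42 ", "integer")             # coerced to int
as_bool = coerce_scalar("True", "boolean")            # coerced to bool
untouched = coerce_scalar("not a number", "integer")  # left as-is
```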
def _coerce_tool_calls_to_schema(
tool_calls: list[ToolCallItem], tools: list[dict[str, Any]]
) -> list[ToolCallItem]:
schema_by_name: dict[str, dict[str, Any]] = {}
for tool in tools:
function = tool.get("function")
if not isinstance(function, dict):
continue
name = function.get("name") # type: ignore
parameters = function.get("parameters") # type: ignore
if isinstance(name, str) and isinstance(parameters, dict):
schema_by_name[name] = parameters
if not schema_by_name:
return tool_calls
coerced_calls: list[ToolCallItem] = []
for tool_call in tool_calls:
schema = schema_by_name.get(tool_call.name)
if schema is None:
coerced_calls.append(tool_call)
continue
try:
parsed_args = json.loads(tool_call.arguments) # pyright: ignore[reportAny]
except Exception:
coerced_calls.append(tool_call)
continue
if not isinstance(parsed_args, dict):
coerced_calls.append(tool_call)
continue
coerced_args = _coerce_tool_arg_with_schema(parsed_args, schema) # pyright: ignore[reportAny]
if not isinstance(coerced_args, dict):
coerced_calls.append(tool_call)
continue
coerced_calls.append(
tool_call.model_copy(update={"arguments": json.dumps(coerced_args)})
)
return coerced_calls
def make_mlx_parser(
@@ -33,7 +204,7 @@ def make_mlx_parser(
return ToolParser(
start_parsing=tool_call_start,
end_parsing=tool_call_end,
-parse_tool_calls=parse_tool_calls,
+_inner_parser=parse_tool_calls,
)
@@ -62,7 +233,7 @@ def make_json_parser() -> ToolParser:
return ToolParser(
start_parsing="<tool_call>",
end_parsing="</tool_call>",
-parse_tool_calls=_parse_json_calls,
+_inner_parser=_parse_json_calls,
)


@@ -12,13 +12,22 @@ from anyio import (
)
from loguru import logger
from exo.shared.types.chunks import ErrorChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
-from exo.shared.types.tasks import Task, TaskId, TaskStatus
+from exo.shared.types.tasks import (
ImageEdits,
ImageGeneration,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
@@ -52,6 +61,7 @@ class RunnerSupervisor:
_tg: TaskGroup = field(default_factory=TaskGroup, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
in_progress: dict[TaskId, Task] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_cancel_watch_runner: anyio.CancelScope = field(
@@ -147,9 +157,11 @@ class RunnerSupervisor:
logger.info(f"Starting task {task}") logger.info(f"Starting task {task}")
event = anyio.Event() event = anyio.Event()
self.pending[task.task_id] = event self.pending[task.task_id] = event
self.in_progress[task.task_id] = task
try: try:
await self._task_sender.send_async(task) await self._task_sender.send_async(task)
except ClosedResourceError: except ClosedResourceError:
self.in_progress.pop(task.task_id, None)
logger.warning(f"Task {task} dropped, runner closed communication.") logger.warning(f"Task {task} dropped, runner closed communication.")
return return
await event.wait() await event.wait()
@@ -157,10 +169,17 @@ class RunnerSupervisor:
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
self.cancelled.add(task_id)
return
self.cancelled.add(task_id)
with anyio.move_on_after(0.5) as scope:
-await self._cancel_sender.send_async(task_id)
+try:
await self._cancel_sender.send_async(task_id)
except ClosedResourceError:
# typically occurs when trying to shut down a failed instance
logger.warning(
f"Cancelling task {task_id} failed, runner closed communication"
)
if scope.cancel_called:
logger.error("RunnerSupervisor cancel pipe blocked")
await self._check_runner(TimeoutError("cancel pipe blocked"))
@@ -189,6 +208,7 @@ class RunnerSupervisor:
RunnerShuttingDown,
),
)
self.in_progress.pop(event.task_id, None)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
@@ -233,6 +253,22 @@ class RunnerSupervisor:
logger.opt(exception=e).error(f"Runner terminated with {cause}")
for task in self.in_progress.values():
if isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)):
with anyio.CancelScope(shield=True):
await self._event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=ErrorChunk(
model=self.shard_metadata.model_card.model_id,
error_message=(
"Runner shutdown before completing command "
f"({cause})"
),
),
)
)
try:
self.status = RunnerFailed(error_message=f"Terminated ({cause})")
with anyio.CancelScope(shield=True):


@@ -20,6 +20,8 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
in_progress: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):


@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
encode_messages,
parse_dsml_output,
)
-from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
+from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────


@@ -6,6 +6,8 @@ from typing import Callable
 import mlx.core as mx
 import pytest

+import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
+import exo.worker.runner.llm_inference.model_output_parsers as mlx_model_output_parsers
 import exo.worker.runner.llm_inference.runner as mlx_runner
 from exo.shared.types.chunks import TokenChunk
 from exo.shared.types.events import (
@@ -114,27 +116,41 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     # initialize_mlx returns a mock group
     monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
     monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
-    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
-    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
-    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
+    monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1))
+    monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
+    monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
+
+    def fake_all_gather(
+        tasks: list[TextGeneration], group: object
+    ) -> tuple[list[TextGeneration], list[TextGeneration]]:
+        return (tasks, [])
+
+    monkeypatch.setattr(mlx_batch_generator, "mx_all_gather_tasks", fake_all_gather)

     # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
     # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
-    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
-    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
+    monkeypatch.setattr(
+        mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
+    )
+    monkeypatch.setattr(
+        mlx_model_output_parsers, "detect_thinking_prompt_suffix", make_nothin(False)
+    )

     def fake_generate(*_1: object, **_2: object):
         yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)

-    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+    monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)


 # Use a fake event_sender to remove test flakiness.
 class EventCollector:
-    def __init__(self) -> None:
+    def __init__(self, on_event: Callable[[Event], None] | None = None) -> None:
         self.events: list[Event] = []
+        self._on_event = on_event

     def send(self, event: Event) -> None:
         self.events.append(event)
+        if self._on_event:
+            self._on_event(event)

     def close(self) -> None:
         pass
@@ -159,7 +175,7 @@ class MockGroup:
         return 1


-def _run(tasks: Iterable[Task]):
+def _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):
     bound_instance = get_bound_mlx_ring_instance(
         instance_id=INSTANCE_1_ID,
         model_id=MODEL_A_ID,
@@ -169,7 +185,23 @@ def _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):
     task_sender, task_receiver = mp_channel[Task]()
     _cancel_sender, cancel_receiver = mp_channel[TaskId]()
-    event_sender = EventCollector()
+    on_event: Callable[[Event], None] | None = None
+    if send_after_ready:
+        _saw_running = False
+
+        def _on_event(event: Event) -> None:
+            nonlocal _saw_running
+            if isinstance(event, RunnerStatusUpdated):
+                if isinstance(event.runner_status, RunnerRunning):
+                    _saw_running = True
+                elif _saw_running and isinstance(event.runner_status, RunnerReady):
+                    for t in send_after_ready:
+                        task_sender.send(t)
+
+        on_event = _on_event
+
+    event_sender = EventCollector(on_event=on_event)

     with task_sender:
         for t in tasks:
@@ -183,18 +215,22 @@ def _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):
         "exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
         make_nothin(mx.array([1])),
     ):
-        mlx_runner.main(
+        runner = mlx_runner.Runner(
             bound_instance,
             event_sender,  # pyright: ignore[reportArgumentType]
             task_receiver,
             cancel_receiver,
         )
+        runner.main()
     return event_sender.events


 def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
-    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
+    events = _run(
+        [INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK],
+        send_after_ready=[SHUTDOWN_TASK],
+    )
     expected_chunk = ChunkGenerated(
         command_id=COMMAND_1_ID,


@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
     GenerationResponse,
     ToolCallResponse,
 )
-from exo.worker.runner.llm_inference.runner import parse_gpt_oss
+from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss

 # Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
 # These are stable since they come from the model's vocabulary.
@@ -107,7 +107,7 @@ def _collect(
     def _gen() -> Generator[GenerationResponse, None, None]:
         yield from _make_gen_responses(tokens)

-    return list(parse_gpt_oss(_gen()))
+    return list(x for x in parse_gpt_oss(_gen()) if x is not None)


 def _get_tool_call(


@@ -1,10 +1,11 @@
 """Tests for parse_tool_calls generator, especially unclosed tool call handling."""

+import json
 from collections.abc import Generator
 from typing import Any

 from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
-from exo.worker.runner.llm_inference.runner import parse_tool_calls
+from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
 from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
@@ -40,6 +41,7 @@ class TestParseToolCalls:
             parse_tool_calls(
                 _make_responses(texts, finish_on_last=False),
                 _dummy_parser,
+                tools=None,
             )
         )
@@ -53,6 +55,7 @@ class TestParseToolCalls:
             parse_tool_calls(
                 _make_responses(texts),
                 _dummy_parser,
+                tools=None,
             )
         )
@@ -77,9 +80,101 @@ class TestParseToolCalls:
             parse_tool_calls(
                 _make_responses(texts, finish_on_last=False),
                 make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
+                tools=None,
             )
         )

         assert len(results) == 1
         assert isinstance(results[0], GenerationResponse)
         assert results[0].text == "<tool_call>bad content</tool_call>"
+
+    def test_tool_schema_coerces_string_arguments_to_expected_types(self):
+        """Tool argument values should be coerced using provided JSON schema."""
+
+        def _parser_with_string_args(_text: str) -> dict[str, Any]:
+            return {
+                "name": "process",
+                "arguments": {
+                    "action": "output",
+                    "id": "0",
+                    "verbose": "true",
+                    "temperature": "0.75",
+                },
+            }
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "process",
+                    "description": "Manage background processes",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "action": {"type": "string"},
+                            "id": {"type": "integer"},
+                            "verbose": {"type": "boolean"},
+                            "temperature": {"type": "number"},
+                        },
+                        "required": ["action"],
+                    },
+                },
+            }
+        ]
+
+        results = list(
+            parse_tool_calls(
+                _make_responses(["<tool_call>", "process", "</tool_call>"]),
+                make_mlx_parser(
+                    "<tool_call>", "</tool_call>", _parser_with_string_args
+                ),
+                tools,
+            )
+        )
+
+        assert len(results) == 1
+        assert isinstance(results[0], ToolCallResponse)
+        args = json.loads(results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
+        assert args == {
+            "action": "output",
+            "id": 0,
+            "verbose": True,
+            "temperature": 0.75,
+        }
+
+    def test_schema_coercion_skips_unknown_tools(self):
+        """If no matching tool schema exists, arguments should remain unchanged."""

+        def _parser_with_string_id(_text: str) -> dict[str, Any]:
+            return {
+                "name": "process",
+                "arguments": {"action": "output", "id": "0"},
+            }
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "different_tool",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"id": {"type": "integer"}},
+                    },
+                },
+            }
+        ]
+
+        results = list(
+            parse_tool_calls(
+                _make_responses(["<tool_call>", "process", "</tool_call>"]),
+                make_mlx_parser("<tool_call>", "</tool_call>", _parser_with_string_id),
+                tools,
+            )
+        )
+
+        assert len(results) == 1
+        assert isinstance(results[0], ToolCallResponse)
+        args = json.loads(results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
+        assert args == {"action": "output", "id": "0"}
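The two new tests pin down a schema-driven coercion step: string argument values emitted by the model are converted to the types declared in the matching tool's JSON schema, while arguments for tools with no matching schema are left untouched. A rough, self-contained sketch of what such a coercion helper could look like (an assumption for illustration, not exo's actual implementation):

```python
from typing import Any


def coerce_arguments(args: dict[str, Any], schema: dict[str, Any]) -> dict[str, Any]:
    """Best-effort coercion of string values to the JSON-schema types
    declared under the schema's "properties" (hypothetical helper)."""
    props = schema.get("properties", {})
    out: dict[str, Any] = {}
    for key, value in args.items():
        expected = props.get(key, {}).get("type")
        if isinstance(value, str):
            try:
                if expected == "integer":
                    value = int(value)
                elif expected == "number":
                    value = float(value)
                elif expected == "boolean" and value.lower() in ("true", "false"):
                    value = value.lower() == "true"
            except ValueError:
                pass  # leave un-coercible values unchanged
        out[key] = value
    return out


schema = {
    "type": "object",
    "properties": {
        "action": {"type": "string"},
        "id": {"type": "integer"},
        "verbose": {"type": "boolean"},
        "temperature": {"type": "number"},
    },
}
coerced = coerce_arguments(
    {"action": "output", "id": "0", "verbose": "true", "temperature": "0.75"},
    schema,
)
print(coerced)  # {'action': 'output', 'id': 0, 'verbose': True, 'temperature': 0.75}
```

Keeping the coercion best-effort (un-parseable values pass through unchanged) matches the second test's expectation that unknown tools and bad values never raise mid-stream.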


@@ -1 +1,93 @@
-# TODO:
+import multiprocessing as mp
+from typing import cast
+
+import anyio
+import pytest
+
+from exo.shared.models.model_cards import ModelId
+from exo.shared.types.chunks import ErrorChunk
+from exo.shared.types.common import CommandId, NodeId
+from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
+from exo.shared.types.tasks import Task, TaskId, TextGeneration
+from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
+from exo.shared.types.worker.instances import BoundInstance, InstanceId
+from exo.shared.types.worker.runners import RunnerFailed, RunnerId
+from exo.utils.channels import channel, mp_channel
+from exo.worker.runner.runner_supervisor import RunnerSupervisor
+from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
+
+
+class _DeadProcess:
+    exitcode = -6
+
+    def start(self) -> None:
+        return None
+
+    def is_alive(self) -> bool:
+        return False
+
+    def join(self, _timeout: float | None = None) -> None:
+        return None
+
+    def terminate(self) -> None:
+        return None
+
+    def kill(self) -> None:
+        return None
+
+
+@pytest.mark.asyncio
+async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> None:
+    event_sender, event_receiver = channel[Event]()
+    task_sender, _ = mp_channel[Task]()
+    cancel_sender, _ = mp_channel[TaskId]()
+    _, ev_recv = mp_channel[Event]()
+
+    bound_instance: BoundInstance = get_bound_mlx_ring_instance(
+        instance_id=InstanceId("instance-a"),
+        model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
+        runner_id=RunnerId("runner-a"),
+        node_id=NodeId("node-a"),
+    )
+
+    supervisor = RunnerSupervisor(
+        shard_metadata=bound_instance.bound_shard,
+        bound_instance=bound_instance,
+        runner_process=cast("mp.Process", cast(object, _DeadProcess())),
+        initialize_timeout=400,
+        _ev_recv=ev_recv,
+        _task_sender=task_sender,
+        _event_sender=event_sender,
+        _cancel_sender=cancel_sender,
+    )
+
+    command_id = CommandId("cmd-a")
+    task = TextGeneration(
+        task_id=TaskId("task-a"),
+        instance_id=bound_instance.instance.instance_id,
+        command_id=command_id,
+        task_params=TextGenerationTaskParams(
+            model=bound_instance.bound_shard.model_card.model_id,
+            input=[InputMessage(role="user", content="hi")],
+            stream=True,
+        ),
+    )
+    supervisor.in_progress[task.task_id] = task
+    supervisor.shutdown = lambda: None
+
+    await supervisor._check_runner(RuntimeError("boom"))  # pyright: ignore[reportPrivateUsage]
+
+    got_chunk = await event_receiver.receive()
+    got_status = await event_receiver.receive()
+
+    assert isinstance(got_chunk, ChunkGenerated)
+    assert got_chunk.command_id == command_id
+    assert isinstance(got_chunk.chunk, ErrorChunk)
+    assert "Runner shutdown before completing command" in got_chunk.chunk.error_message
+
+    assert isinstance(got_status, RunnerStatusUpdated)
+    assert isinstance(got_status.runner_status, RunnerFailed)
+
+    event_sender.close()
+    with anyio.move_on_after(0.1):
+        await event_receiver.aclose()