mirror of
https://github.com/exo-explore/exo.git
synced 2026-03-06 07:06:28 -05:00
Compare commits
12 Commits
leo/prepar
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a36d3968d | ||
|
|
eee3432738 | ||
|
|
e8c3a873a6 | ||
|
|
afab3095b0 | ||
|
|
b9d40e8e35 | ||
|
|
3a4d635d0c | ||
|
|
8485805042 | ||
|
|
4de8f801c7 | ||
|
|
5777bf3c39 | ||
|
|
886192f1e6 | ||
|
|
d914acd64e | ||
|
|
37296c8249 |
@@ -48,7 +48,7 @@ def make_logits_processors(
|
|||||||
logit_bias: Optional[Dict[int, float]] = ...,
|
logit_bias: Optional[Dict[int, float]] = ...,
|
||||||
repetition_penalty: Optional[float] = ...,
|
repetition_penalty: Optional[float] = ...,
|
||||||
repetition_context_size: Optional[int] = ...,
|
repetition_context_size: Optional[int] = ...,
|
||||||
): # -> list[Any]:
|
) -> list[Callable[[mx.array, mx.array], mx.array]]:
|
||||||
"""
|
"""
|
||||||
Make logits processors for use with ``generate_step``.
|
Make logits processors for use with ``generate_step``.
|
||||||
|
|
||||||
|
|||||||
113
CONTRIBUTING.md
113
CONTRIBUTING.md
@@ -39,6 +39,119 @@ Write pure functions where possible. When adding new code, prefer Rust unless th
|
|||||||
|
|
||||||
Run `nix fmt` to auto-format your code before submitting.
|
Run `nix fmt` to auto-format your code before submitting.
|
||||||
|
|
||||||
|
## Model Cards
|
||||||
|
|
||||||
|
EXO uses TOML-based model cards to define model metadata and capabilities. Model cards are stored in:
|
||||||
|
- `resources/inference_model_cards/` for text generation models
|
||||||
|
- `resources/image_model_cards/` for image generation models
|
||||||
|
- `~/.exo/custom_model_cards/` for user-added custom models
|
||||||
|
|
||||||
|
### Adding a Model Card
|
||||||
|
|
||||||
|
To add a new model, create a TOML file with the following structure:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||||
|
n_layers = 16
|
||||||
|
hidden_size = 2048
|
||||||
|
supports_tensor = true
|
||||||
|
tasks = ["TextGeneration"]
|
||||||
|
family = "llama"
|
||||||
|
quantization = "4bit"
|
||||||
|
base_model = "Llama 3.2 1B"
|
||||||
|
capabilities = ["text"]
|
||||||
|
|
||||||
|
[storage_size]
|
||||||
|
in_bytes = 729808896
|
||||||
|
```
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
- `model_id`: Hugging Face model identifier
|
||||||
|
- `n_layers`: Number of transformer layers
|
||||||
|
- `hidden_size`: Hidden dimension size
|
||||||
|
- `supports_tensor`: Whether the model supports tensor parallelism
|
||||||
|
- `tasks`: List of supported tasks (`TextGeneration`, `TextToImage`, `ImageToImage`)
|
||||||
|
- `family`: Model family (e.g., "llama", "deepseek", "qwen")
|
||||||
|
- `quantization`: Quantization level (e.g., "4bit", "8bit", "bf16")
|
||||||
|
- `base_model`: Human-readable base model name
|
||||||
|
- `capabilities`: List of capabilities (e.g., `["text"]`, `["text", "thinking"]`)
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
- `components`: For multi-component models (like image models with separate text encoders and transformers)
|
||||||
|
- `uses_cfg`: Whether the model uses classifier-free guidance (for image models)
|
||||||
|
- `trust_remote_code`: Whether to allow remote code execution (defaults to `false` for security)
|
||||||
|
|
||||||
|
### Capabilities
|
||||||
|
|
||||||
|
The `capabilities` field defines what the model can do:
|
||||||
|
- `text`: Standard text generation
|
||||||
|
- `thinking`: Model supports chain-of-thought reasoning
|
||||||
|
- `thinking_toggle`: Thinking can be enabled/disabled via `enable_thinking` parameter
|
||||||
|
- `image_edit`: Model supports image-to-image editing (FLUX.1-Kontext)
|
||||||
|
|
||||||
|
### Security Note
|
||||||
|
|
||||||
|
By default, `trust_remote_code` is set to `false` for security. Only enable it if the model explicitly requires remote code execution from the Hugging Face hub.
|
||||||
|
|
||||||
|
## API Adapters
|
||||||
|
|
||||||
|
EXO supports multiple API formats through an adapter pattern. Adapters convert API-specific request formats to the internal `TextGenerationTaskParams` format and convert internal token chunks back to API-specific responses.
|
||||||
|
|
||||||
|
### Adapter Architecture
|
||||||
|
|
||||||
|
All adapters live in `src/exo/master/adapters/` and follow the same pattern:
|
||||||
|
|
||||||
|
1. Convert API-specific requests to `TextGenerationTaskParams`
|
||||||
|
2. Handle both streaming and non-streaming response generation
|
||||||
|
3. Convert internal `TokenChunk` objects to API-specific formats
|
||||||
|
4. Manage error handling and edge cases
|
||||||
|
|
||||||
|
### Existing Adapters
|
||||||
|
|
||||||
|
- `chat_completions.py`: OpenAI Chat Completions API
|
||||||
|
- `claude.py`: Anthropic Claude Messages API
|
||||||
|
- `responses.py`: OpenAI Responses API
|
||||||
|
- `ollama.py`: Ollama API (for OpenWebUI compatibility)
|
||||||
|
|
||||||
|
### Adding a New API Adapter
|
||||||
|
|
||||||
|
To add support for a new API format:
|
||||||
|
|
||||||
|
1. Create a new adapter file in `src/exo/master/adapters/`
|
||||||
|
2. Implement a request conversion function:
|
||||||
|
```python
|
||||||
|
def your_api_request_to_text_generation(
|
||||||
|
request: YourAPIRequest,
|
||||||
|
) -> TextGenerationTaskParams:
|
||||||
|
# Convert API request to internal format
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
3. Implement streaming response generation:
|
||||||
|
```python
|
||||||
|
async def generate_your_api_stream(
|
||||||
|
command_id: CommandId,
|
||||||
|
chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None],
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
# Convert internal chunks to API-specific streaming format
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
4. Implement non-streaming response collection:
|
||||||
|
```python
|
||||||
|
async def collect_your_api_response(
|
||||||
|
command_id: CommandId,
|
||||||
|
chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None],
|
||||||
|
) -> AsyncGenerator[str]:
|
||||||
|
# Collect all chunks and return single response
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
5. Register the adapter endpoints in `src/exo/master/api.py`
|
||||||
|
|
||||||
|
The adapter pattern keeps API-specific logic isolated from core inference systems. Internal systems (worker, runner, event sourcing) only see `TextGenerationTaskParams` and `TokenChunk` objects - no API-specific types cross the adapter boundary.
|
||||||
|
|
||||||
|
For detailed API documentation, see [docs/api.md](docs/api.md).
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
EXO relies heavily on manual testing at this point in the project, but this is evolving. Before submitting a change, test both before and after to demonstrate how your change improves behavior. Do the best you can with the hardware you have available - if you need help testing, ask and we'll do our best to assist. Add automated tests where possible - we're actively working to substantially improve our automated testing story.
|
EXO relies heavily on manual testing at this point in the project, but this is evolving. Before submitting a change, test both before and after to demonstrate how your change improves behavior. Do the best you can with the hardware you have available - if you need help testing, ask and we'll do our best to assist. Add automated tests where possible - we're actively working to substantially improve our automated testing story.
|
||||||
|
|||||||
118
README.md
118
README.md
@@ -26,6 +26,8 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
|
|||||||
- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.
|
- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.
|
||||||
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
|
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
|
||||||
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
|
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
|
||||||
|
- **Multiple API Compatibility**: Compatible with OpenAI Chat Completions API, Claude Messages API, OpenAI Responses API, and Ollama API - use your existing tools and clients.
|
||||||
|
- **Custom Model Support**: Load custom models from HuggingFace hub to expand the range of available models.
|
||||||
|
|
||||||
## Dashboard
|
## Dashboard
|
||||||
|
|
||||||
@@ -196,6 +198,8 @@ exo follows the [XDG Base Directory Specification](https://specifications.freede
|
|||||||
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
|
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
|
||||||
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
|
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
|
||||||
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
|
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
|
||||||
|
- **Log files**: `~/.cache/exo/exo_log/` (with automatic log rotation)
|
||||||
|
- **Custom model cards**: `~/.local/share/exo/custom_model_cards/`
|
||||||
|
|
||||||
You can override these locations by setting the corresponding XDG environment variables.
|
You can override these locations by setting the corresponding XDG environment variables.
|
||||||
|
|
||||||
@@ -275,8 +279,47 @@ After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
exo supports several environment variables for configuration:
|
||||||
|
|
||||||
|
| Variable | Description | Default |
|
||||||
|
|----------|-------------|---------|
|
||||||
|
| `EXO_MODELS_PATH` | Colon-separated paths to search for pre-downloaded models (e.g., on NFS mounts or shared storage) | None |
|
||||||
|
| `EXO_MODELS_DIR` | Directory where exo downloads and stores models | `~/.local/share/exo/models` (Linux) or `~/.exo/models` (macOS) |
|
||||||
|
| `EXO_OFFLINE` | Run without internet connection (uses only local models) | `false` |
|
||||||
|
| `EXO_ENABLE_IMAGE_MODELS` | Enable image model support | `false` |
|
||||||
|
| `EXO_LIBP2P_NAMESPACE` | Custom namespace for cluster isolation | None |
|
||||||
|
| `EXO_FAST_SYNCH` | Control MLX_METAL_FAST_SYNCH behavior (for JACCL backend) | Auto |
|
||||||
|
| `EXO_TRACING_ENABLED` | Enable distributed tracing for performance analysis | `false` |
|
||||||
|
|
||||||
|
**Example usage:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use pre-downloaded models from NFS mount
|
||||||
|
EXO_MODELS_PATH=/mnt/nfs/models:/opt/ai-models uv run exo
|
||||||
|
|
||||||
|
# Run in offline mode
|
||||||
|
EXO_OFFLINE=true uv run exo
|
||||||
|
|
||||||
|
# Enable image models
|
||||||
|
EXO_ENABLE_IMAGE_MODELS=true uv run exo
|
||||||
|
|
||||||
|
# Use custom namespace for cluster isolation
|
||||||
|
EXO_LIBP2P_NAMESPACE=my-dev-cluster uv run exo
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### Using the API
|
### Using the API
|
||||||
|
|
||||||
|
exo provides multiple API-compatible interfaces for maximum compatibility with existing tools:
|
||||||
|
|
||||||
|
- **OpenAI Chat Completions API** - Compatible with OpenAI clients
|
||||||
|
- **Claude Messages API** - Compatible with Anthropic's Claude format
|
||||||
|
- **OpenAI Responses API** - Compatible with OpenAI's Responses format
|
||||||
|
- **Ollama API** - Compatible with Ollama and tools like OpenWebUI
|
||||||
|
|
||||||
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
|
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -366,14 +409,85 @@ When you're done, delete the instance by its ID (find it via `/state` or `/insta
|
|||||||
curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
|
curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Claude Messages API Compatibility
|
||||||
|
|
||||||
|
Use the Claude Messages API format with the `/v1/messages` endpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N -X POST http://localhost:52415/v1/messages \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"stream": true
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI Responses API Compatibility
|
||||||
|
|
||||||
|
Use the OpenAI Responses API format with the `/v1/responses` endpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N -X POST http://localhost:52415/v1/responses \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"stream": true
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Ollama API Compatibility
|
||||||
|
|
||||||
|
exo supports Ollama API endpoints for compatibility with tools like OpenWebUI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ollama chat
|
||||||
|
curl -X POST http://localhost:52415/ollama/api/chat \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"stream": false
|
||||||
|
}'
|
||||||
|
|
||||||
|
# List models (Ollama format)
|
||||||
|
curl http://localhost:52415/ollama/api/tags
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Model Loading from HuggingFace
|
||||||
|
|
||||||
|
You can add custom models from the HuggingFace hub:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:52415/models/add \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model_id": "mlx-community/my-custom-model"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Security Note:**
|
||||||
|
|
||||||
|
Custom models requiring `trust_remote_code` in their configuration must be explicitly enabled (default is false) for security. Only enable this if you trust the model's remote code execution. Models are fetched from HuggingFace and stored locally as custom model cards.
|
||||||
|
|
||||||
**Other useful API endpoints*:**
|
**Other useful API endpoints*:**
|
||||||
|
|
||||||
- List all models: `curl http://localhost:52415/models`
|
- List all models: `curl http://localhost:52415/models`
|
||||||
|
- List downloaded models only: `curl http://localhost:52415/models?status=downloaded`
|
||||||
|
- Search HuggingFace: `curl "http://localhost:52415/models/search?query=llama&limit=10"`
|
||||||
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
|
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
|
||||||
|
|
||||||
For further details, see:
|
For further details, see:
|
||||||
|
|
||||||
- API basic documentation in [docs/api.md](docs/api.md).
|
- API documentation in [docs/api.md](docs/api.md).
|
||||||
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
|
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -432,4 +546,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
|
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
|
||||||
|
|||||||
@@ -36,8 +36,12 @@
|
|||||||
return num.toString();
|
return num.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract model name from full ID (e.g., "mlx-community/Llama-3.2-1B" -> "Llama-3.2-1B")
|
// Show short name for mlx-community models, full ID for everything else
|
||||||
const modelName = $derived(model.id.split("/").pop() || model.id);
|
const modelName = $derived(
|
||||||
|
model.author === "mlx-community"
|
||||||
|
? model.id.split("/").pop() || model.id
|
||||||
|
: model.id,
|
||||||
|
);
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
|
|||||||
@@ -507,9 +507,29 @@
|
|||||||
});
|
});
|
||||||
|
|
||||||
$effect(() => {
|
$effect(() => {
|
||||||
if (containerRef && processedHtml) {
|
if (!containerRef || !browser) return;
|
||||||
setupCopyButtons();
|
|
||||||
|
function handleDelegatedClick(event: MouseEvent) {
|
||||||
|
const codeBtn = (event.target as HTMLElement).closest(
|
||||||
|
".copy-code-btn",
|
||||||
|
) as HTMLButtonElement | null;
|
||||||
|
if (codeBtn) {
|
||||||
|
handleCopyClick({ currentTarget: codeBtn } as unknown as Event);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const mathBtn = (event.target as HTMLElement).closest(
|
||||||
|
".copy-math-btn",
|
||||||
|
) as HTMLButtonElement | null;
|
||||||
|
if (mathBtn) {
|
||||||
|
handleMathCopyClick({ currentTarget: mathBtn } as unknown as Event);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
containerRef.addEventListener("click", handleDelegatedClick);
|
||||||
|
return () => {
|
||||||
|
containerRef?.removeEventListener("click", handleDelegatedClick);
|
||||||
|
};
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@
|
|||||||
group: ModelGroup;
|
group: ModelGroup;
|
||||||
isExpanded: boolean;
|
isExpanded: boolean;
|
||||||
isFavorite: boolean;
|
isFavorite: boolean;
|
||||||
|
isHighlighted?: boolean;
|
||||||
selectedModelId: string | null;
|
selectedModelId: string | null;
|
||||||
canModelFit: (id: string) => boolean;
|
canModelFit: (id: string) => boolean;
|
||||||
getModelFitStatus: (id: string) => ModelFitStatus;
|
getModelFitStatus: (id: string) => ModelFitStatus;
|
||||||
@@ -48,6 +49,7 @@
|
|||||||
group,
|
group,
|
||||||
isExpanded,
|
isExpanded,
|
||||||
isFavorite,
|
isFavorite,
|
||||||
|
isHighlighted = false,
|
||||||
selectedModelId,
|
selectedModelId,
|
||||||
canModelFit,
|
canModelFit,
|
||||||
getModelFitStatus,
|
getModelFitStatus,
|
||||||
@@ -150,10 +152,11 @@
|
|||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
|
data-model-ids={group.variants.map((v) => v.id).join(" ")}
|
||||||
class="border-b border-white/5 last:border-b-0 {!anyVariantFits &&
|
class="border-b border-white/5 last:border-b-0 {!anyVariantFits &&
|
||||||
!anyVariantHasInstance
|
!anyVariantHasInstance
|
||||||
? 'opacity-50'
|
? 'opacity-50'
|
||||||
: ''}"
|
: ''} {isHighlighted ? 'model-just-added' : ''}"
|
||||||
>
|
>
|
||||||
<!-- Main row -->
|
<!-- Main row -->
|
||||||
<div
|
<div
|
||||||
@@ -644,3 +647,21 @@
|
|||||||
</div>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<style>
|
||||||
|
.model-just-added {
|
||||||
|
animation: highlightFade 4s ease-out forwards;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes highlightFade {
|
||||||
|
0%,
|
||||||
|
40% {
|
||||||
|
background-color: rgba(20, 83, 45, 0.25);
|
||||||
|
box-shadow: inset 0 0 0 1px rgba(74, 222, 128, 0.4);
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
background-color: transparent;
|
||||||
|
box-shadow: none;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
|
import { tick } from "svelte";
|
||||||
import { fade, fly } from "svelte/transition";
|
import { fade, fly } from "svelte/transition";
|
||||||
import { cubicOut } from "svelte/easing";
|
import { cubicOut } from "svelte/easing";
|
||||||
import FamilySidebar from "./FamilySidebar.svelte";
|
import FamilySidebar from "./FamilySidebar.svelte";
|
||||||
@@ -7,6 +8,7 @@
|
|||||||
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
|
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
|
||||||
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
|
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
|
||||||
import { getRecentEntries } from "$lib/stores/recents.svelte";
|
import { getRecentEntries } from "$lib/stores/recents.svelte";
|
||||||
|
import { addToast } from "$lib/stores/toast.svelte";
|
||||||
|
|
||||||
interface ModelInfo {
|
interface ModelInfo {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -191,6 +193,13 @@
|
|||||||
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
|
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
|
||||||
let manualModelId = $state("");
|
let manualModelId = $state("");
|
||||||
let addModelError = $state<string | null>(null);
|
let addModelError = $state<string | null>(null);
|
||||||
|
let justAddedModelId = $state<string | null>(null);
|
||||||
|
let justAddedTimer: ReturnType<typeof setTimeout> | null = null;
|
||||||
|
|
||||||
|
// Inline HuggingFace search in main search bar
|
||||||
|
let mainSearchHfResults = $state<HuggingFaceModel[]>([]);
|
||||||
|
let mainSearchHfLoading = $state(false);
|
||||||
|
let mainSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
|
||||||
|
|
||||||
// Reset transient state when modal opens, but preserve tab selection
|
// Reset transient state when modal opens, but preserve tab selection
|
||||||
$effect(() => {
|
$effect(() => {
|
||||||
@@ -200,6 +209,11 @@
|
|||||||
showFilters = false;
|
showFilters = false;
|
||||||
manualModelId = "";
|
manualModelId = "";
|
||||||
addModelError = null;
|
addModelError = null;
|
||||||
|
justAddedModelId = null;
|
||||||
|
if (justAddedTimer) {
|
||||||
|
clearTimeout(justAddedTimer);
|
||||||
|
justAddedTimer = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -214,6 +228,50 @@
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Inline HuggingFace search when local search returns no results
|
||||||
|
$effect(() => {
|
||||||
|
const query = searchQuery.trim();
|
||||||
|
const noLocalResults = filteredGroups.length === 0;
|
||||||
|
|
||||||
|
if (mainSearchDebounceTimer) {
|
||||||
|
clearTimeout(mainSearchDebounceTimer);
|
||||||
|
mainSearchDebounceTimer = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
selectedFamily === "huggingface" ||
|
||||||
|
selectedFamily === "recents" ||
|
||||||
|
selectedFamily === "favorites" ||
|
||||||
|
query.length < 2 ||
|
||||||
|
!noLocalResults
|
||||||
|
) {
|
||||||
|
mainSearchHfResults = [];
|
||||||
|
mainSearchHfLoading = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
mainSearchHfLoading = true;
|
||||||
|
mainSearchDebounceTimer = setTimeout(async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(
|
||||||
|
`/models/search?query=${encodeURIComponent(query)}&limit=10`,
|
||||||
|
);
|
||||||
|
if (response.ok) {
|
||||||
|
const results: HuggingFaceModel[] = await response.json();
|
||||||
|
mainSearchHfResults = results.filter(
|
||||||
|
(r) => !existingModelIds.has(r.id),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
mainSearchHfResults = [];
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
mainSearchHfResults = [];
|
||||||
|
} finally {
|
||||||
|
mainSearchHfLoading = false;
|
||||||
|
}
|
||||||
|
}, 500);
|
||||||
|
});
|
||||||
|
|
||||||
async function fetchTrendingModels() {
|
async function fetchTrendingModels() {
|
||||||
hfIsLoadingTrending = true;
|
hfIsLoadingTrending = true;
|
||||||
try {
|
try {
|
||||||
@@ -274,6 +332,24 @@
|
|||||||
addModelError = null;
|
addModelError = null;
|
||||||
try {
|
try {
|
||||||
await onAddModel(modelId);
|
await onAddModel(modelId);
|
||||||
|
// Success: show toast, switch to All Models, highlight the model
|
||||||
|
const shortName = modelId.split("/").pop() || modelId;
|
||||||
|
addToast({ type: "success", message: `Added ${shortName}` });
|
||||||
|
justAddedModelId = modelId;
|
||||||
|
selectedFamily = null;
|
||||||
|
searchQuery = "";
|
||||||
|
// Scroll to the newly added model after DOM update
|
||||||
|
await tick();
|
||||||
|
const el = document.querySelector(
|
||||||
|
`[data-model-ids~="${CSS.escape(modelId)}"]`,
|
||||||
|
);
|
||||||
|
el?.scrollIntoView({ behavior: "smooth", block: "center" });
|
||||||
|
// Clear highlight after 4 seconds
|
||||||
|
if (justAddedTimer) clearTimeout(justAddedTimer);
|
||||||
|
justAddedTimer = setTimeout(() => {
|
||||||
|
justAddedModelId = null;
|
||||||
|
justAddedTimer = null;
|
||||||
|
}, 4000);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
addModelError =
|
addModelError =
|
||||||
error instanceof Error ? error.message : "Failed to add model";
|
error instanceof Error ? error.message : "Failed to add model";
|
||||||
@@ -841,6 +917,8 @@
|
|||||||
{group}
|
{group}
|
||||||
isExpanded={expandedGroups.has(group.id)}
|
isExpanded={expandedGroups.has(group.id)}
|
||||||
isFavorite={favorites.has(group.id)}
|
isFavorite={favorites.has(group.id)}
|
||||||
|
isHighlighted={justAddedModelId !== null &&
|
||||||
|
group.variants.some((v) => v.id === justAddedModelId)}
|
||||||
{selectedModelId}
|
{selectedModelId}
|
||||||
{canModelFit}
|
{canModelFit}
|
||||||
{getModelFitStatus}
|
{getModelFitStatus}
|
||||||
@@ -910,6 +988,8 @@
|
|||||||
{group}
|
{group}
|
||||||
isExpanded={expandedGroups.has(group.id)}
|
isExpanded={expandedGroups.has(group.id)}
|
||||||
isFavorite={favorites.has(group.id)}
|
isFavorite={favorites.has(group.id)}
|
||||||
|
isHighlighted={justAddedModelId !== null &&
|
||||||
|
group.variants.some((v) => v.id === justAddedModelId)}
|
||||||
{selectedModelId}
|
{selectedModelId}
|
||||||
{canModelFit}
|
{canModelFit}
|
||||||
{getModelFitStatus}
|
{getModelFitStatus}
|
||||||
@@ -937,6 +1017,8 @@
|
|||||||
{group}
|
{group}
|
||||||
isExpanded={expandedGroups.has(group.id)}
|
isExpanded={expandedGroups.has(group.id)}
|
||||||
isFavorite={favorites.has(group.id)}
|
isFavorite={favorites.has(group.id)}
|
||||||
|
isHighlighted={justAddedModelId !== null &&
|
||||||
|
group.variants.some((v) => v.id === justAddedModelId)}
|
||||||
{selectedModelId}
|
{selectedModelId}
|
||||||
{canModelFit}
|
{canModelFit}
|
||||||
{getModelFitStatus}
|
{getModelFitStatus}
|
||||||
@@ -948,6 +1030,55 @@
|
|||||||
{instanceStatuses}
|
{instanceStatuses}
|
||||||
/>
|
/>
|
||||||
{/each}
|
{/each}
|
||||||
|
<!-- Inline HuggingFace search results (shown when no local results match) -->
|
||||||
|
{#if filteredGroups.length === 0 && searchQuery.trim().length >= 2 && selectedFamily !== "huggingface" && selectedFamily !== "recents" && selectedFamily !== "favorites"}
|
||||||
|
{#if mainSearchHfLoading}
|
||||||
|
<div
|
||||||
|
class="flex items-center gap-2 px-3 py-2 border-t border-orange-400/20 bg-orange-950/20"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
class="w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin"
|
||||||
|
></span>
|
||||||
|
<span class="text-xs font-mono text-orange-400/60"
|
||||||
|
>Searching HuggingFace...</span
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
{:else if mainSearchHfResults.length > 0}
|
||||||
|
<div
|
||||||
|
class="sticky top-0 z-10 flex items-center gap-2 px-3 py-2 bg-orange-950/30 border-y border-orange-400/20 backdrop-blur-sm"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
class="text-xs font-mono text-orange-400 tracking-wider uppercase"
|
||||||
|
>From HuggingFace</span
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
{#each mainSearchHfResults as model}
|
||||||
|
<HuggingFaceResultItem
|
||||||
|
{model}
|
||||||
|
isAdded={existingModelIds.has(model.id)}
|
||||||
|
isAdding={addingModelId === model.id}
|
||||||
|
onAdd={() => handleAddModel(model.id)}
|
||||||
|
onSelect={() => handleSelectHfModel(model.id)}
|
||||||
|
downloadedOnNodes={downloadsData
|
||||||
|
? getNodesWithModelDownloaded(downloadsData, model.id).map(
|
||||||
|
getNodeName,
|
||||||
|
)
|
||||||
|
: []}
|
||||||
|
/>
|
||||||
|
{/each}
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="w-full px-3 py-2 text-xs font-mono text-orange-400/60 hover:text-orange-400 hover:bg-orange-500/10 transition-colors text-center"
|
||||||
|
onclick={() => {
|
||||||
|
hfSearchQuery = searchQuery;
|
||||||
|
searchHuggingFace(searchQuery);
|
||||||
|
selectedFamily = "huggingface";
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
See all results on Hub
|
||||||
|
</button>
|
||||||
|
{/if}
|
||||||
|
{/if}
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
476
docs/api.md
476
docs/api.md
@@ -4,7 +4,7 @@ This document describes the REST API exposed by the **EXO** service, as implemen
|
|||||||
|
|
||||||
`src/exo/master/api.py`
|
`src/exo/master/api.py`
|
||||||
|
|
||||||
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
|
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using multiple API-compatible interfaces.
|
||||||
|
|
||||||
Base URL example:
|
Base URL example:
|
||||||
|
|
||||||
@@ -144,8 +144,59 @@ Placement result.
|
|||||||
|
|
||||||
Returns the list of available models and their metadata.
|
Returns the list of available models and their metadata.
|
||||||
|
|
||||||
|
**Query parameters:**
|
||||||
|
|
||||||
|
* `status`: string (optional) - Filter by `downloaded` to show only downloaded models
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
Array of model descriptors.
|
Array of model descriptors including `is_custom` field for custom HuggingFace models.
|
||||||
|
|
||||||
|
### Add Custom Model
|
||||||
|
|
||||||
|
**POST** `/models/add`
|
||||||
|
|
||||||
|
Add a custom model from HuggingFace hub.
|
||||||
|
|
||||||
|
**Request body (example):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model_id": "mlx-community/my-custom-model"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Model descriptor for the added model.
|
||||||
|
|
||||||
|
**Security note:**
|
||||||
|
Models with `trust_remote_code` enabled in their configuration require explicit opt-in (default is false) for security.
|
||||||
|
|
||||||
|
### Delete Custom Model
|
||||||
|
|
||||||
|
**DELETE** `/models/custom/{model_id}`
|
||||||
|
|
||||||
|
Delete a user-added custom model card.
|
||||||
|
|
||||||
|
**Path parameters:**
|
||||||
|
|
||||||
|
* `model_id`: string, ID of the custom model to delete
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Confirmation JSON with deleted model ID.
|
||||||
|
|
||||||
|
### Search Models
|
||||||
|
|
||||||
|
**GET** `/models/search`
|
||||||
|
|
||||||
|
Search HuggingFace Hub for mlx-community models.
|
||||||
|
|
||||||
|
**Query parameters:**
|
||||||
|
|
||||||
|
* `query`: string (optional) - Search query
|
||||||
|
* `limit`: integer (default: 20) - Maximum number of results
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Array of HuggingFace model search results.
|
||||||
|
|
||||||
## 4. Inference / Chat Completions
|
## 4. Inference / Chat Completions
|
||||||
|
|
||||||
@@ -168,9 +219,123 @@ Executes a chat completion request using an OpenAI-compatible schema. Supports s
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Request parameters:**
|
||||||
|
|
||||||
|
* `model`: string, required - Model ID to use
|
||||||
|
* `messages`: array, required - Conversation messages
|
||||||
|
* `stream`: boolean (default: false) - Enable streaming responses
|
||||||
|
* `max_tokens`: integer (optional) - Maximum tokens to generate
|
||||||
|
* `temperature`: float (optional) - Sampling temperature
|
||||||
|
* `top_p`: float (optional) - Nucleus sampling parameter
|
||||||
|
* `top_k`: integer (optional) - Top-k sampling parameter
|
||||||
|
* `stop`: string or array (optional) - Stop sequences
|
||||||
|
* `seed`: integer (optional) - Random seed for reproducibility
|
||||||
|
* `enable_thinking`: boolean (optional) - Enable thinking mode for capable models (DeepSeek V3.1, Qwen3, GLM-4.7)
|
||||||
|
* `tools`: array (optional) - Tool definitions for function calling
|
||||||
|
* `logprobs`: boolean (optional) - Return log probabilities
|
||||||
|
* `top_logprobs`: integer (optional) - Number of top log probabilities to return
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
OpenAI-compatible chat completion response.
|
OpenAI-compatible chat completion response.
|
||||||
|
|
||||||
|
**Streaming response format:**
|
||||||
|
When `stream=true`, returns Server-Sent Events (SSE) with format:
|
||||||
|
|
||||||
|
```
|
||||||
|
data: {"id":"...","object":"chat.completion","created":...,"model":"...","choices":[...]}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Non-streaming response includes usage statistics:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "...",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "llama-3.2-1b",
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello! How can I help you?"
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 15,
|
||||||
|
"completion_tokens": 8,
|
||||||
|
"total_tokens": 23
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Cancellation:**
|
||||||
|
You can cancel an active generation by closing the HTTP connection. The server detects the disconnection and stops processing.
|
||||||
|
|
||||||
|
### Claude Messages API
|
||||||
|
|
||||||
|
**POST** `/v1/messages`
|
||||||
|
|
||||||
|
Executes a chat completion request using the Claude Messages API format. Supports streaming and non-streaming modes.
|
||||||
|
|
||||||
|
**Request body (example):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model": "llama-3.2-1b",
|
||||||
|
"messages": [
|
||||||
|
{ "role": "user", "content": "Hello" }
|
||||||
|
],
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"stream": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Streaming response format:**
|
||||||
|
When `stream=true`, returns Server-Sent Events with Claude-specific event types:
|
||||||
|
|
||||||
|
* `message_start` - Message generation started
|
||||||
|
* `content_block_start` - Content block started
|
||||||
|
* `content_block_delta` - Incremental content chunk
|
||||||
|
* `content_block_stop` - Content block completed
|
||||||
|
* `message_delta` - Message metadata updates
|
||||||
|
* `message_stop` - Message generation completed
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Claude-compatible messages response.
|
||||||
|
|
||||||
|
### OpenAI Responses API
|
||||||
|
|
||||||
|
**POST** `/v1/responses`
|
||||||
|
|
||||||
|
Executes a chat completion request using the OpenAI Responses API format. Supports streaming and non-streaming modes.
|
||||||
|
|
||||||
|
**Request body (example):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model": "llama-3.2-1b",
|
||||||
|
"messages": [
|
||||||
|
{ "role": "user", "content": "Hello" }
|
||||||
|
],
|
||||||
|
"stream": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Streaming response format:**
|
||||||
|
When `stream=true`, returns Server-Sent Events with response-specific event types:
|
||||||
|
|
||||||
|
* `response.created` - Response generation started
|
||||||
|
* `response.in_progress` - Response is being generated
|
||||||
|
* `response.output_item.added` - New output item added
|
||||||
|
* `response.output_item.done` - Output item completed
|
||||||
|
* `response.done` - Response generation completed
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
OpenAI Responses API-compatible response.
|
||||||
|
|
||||||
### Benchmarked Chat Completions
|
### Benchmarked Chat Completions
|
||||||
|
|
||||||
**POST** `/bench/chat/completions`
|
**POST** `/bench/chat/completions`
|
||||||
@@ -181,7 +346,13 @@ Same as `/v1/chat/completions`, but also returns performance and generation stat
|
|||||||
Same schema as `/v1/chat/completions`.
|
Same schema as `/v1/chat/completions`.
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
Chat completion plus benchmarking metrics.
|
Chat completion plus benchmarking metrics including:
|
||||||
|
|
||||||
|
* `prompt_tps` - Tokens per second during prompt processing
|
||||||
|
* `generation_tps` - Tokens per second during generation
|
||||||
|
* `prompt_tokens` - Number of prompt tokens
|
||||||
|
* `generation_tokens` - Number of generated tokens
|
||||||
|
* `peak_memory_usage` - Peak memory used during generation
|
||||||
|
|
||||||
### Cancel Command
|
### Cancel Command
|
||||||
|
|
||||||
@@ -204,24 +375,148 @@ Cancels an active generation command (text or image). Notifies workers and close
|
|||||||
|
|
||||||
Returns 404 if the command is not found or already completed.
|
Returns 404 if the command is not found or already completed.
|
||||||
|
|
||||||
## 5. Image Generation & Editing
|
## 5. Ollama API Compatibility
|
||||||
|
|
||||||
|
EXO provides Ollama API compatibility for tools like OpenWebUI.
|
||||||
|
|
||||||
|
### Ollama Chat
|
||||||
|
|
||||||
|
**POST** `/ollama/api/chat`
|
||||||
|
**POST** `/ollama/api/api/chat` (alias)
|
||||||
|
**POST** `/ollama/api/v1/chat` (alias)
|
||||||
|
|
||||||
|
Execute a chat request using Ollama API format.
|
||||||
|
|
||||||
|
**Request body (example):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model": "llama-3.2-1b",
|
||||||
|
"messages": [
|
||||||
|
{ "role": "user", "content": "Hello" }
|
||||||
|
],
|
||||||
|
"stream": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Ollama-compatible chat response.
|
||||||
|
|
||||||
|
### Ollama Generate
|
||||||
|
|
||||||
|
**POST** `/ollama/api/generate`
|
||||||
|
|
||||||
|
Execute a text generation request using Ollama API format.
|
||||||
|
|
||||||
|
**Request body (example):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model": "llama-3.2-1b",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Ollama-compatible generation response.
|
||||||
|
|
||||||
|
### Ollama Tags
|
||||||
|
|
||||||
|
**GET** `/ollama/api/tags`
|
||||||
|
**GET** `/ollama/api/api/tags` (alias)
|
||||||
|
**GET** `/ollama/api/v1/tags` (alias)
|
||||||
|
|
||||||
|
Returns list of downloaded models in Ollama tags format.
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Array of model tags with metadata.
|
||||||
|
|
||||||
|
### Ollama Show
|
||||||
|
|
||||||
|
**POST** `/ollama/api/show`
|
||||||
|
|
||||||
|
Returns model information in Ollama show format.
|
||||||
|
|
||||||
|
**Request body:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "llama-3.2-1b"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Model details including modelfile and family.
|
||||||
|
|
||||||
|
### Ollama PS
|
||||||
|
|
||||||
|
**GET** `/ollama/api/ps`
|
||||||
|
|
||||||
|
Returns list of running models (active instances).
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Array of active model instances.
|
||||||
|
|
||||||
|
### Ollama Version
|
||||||
|
|
||||||
|
**GET** `/ollama/api/version`
|
||||||
|
**HEAD** `/ollama/` (alias)
|
||||||
|
**HEAD** `/ollama/api/version` (alias)
|
||||||
|
|
||||||
|
Returns version information for Ollama API compatibility.
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "exo v1.0"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. Image Generation & Editing
|
||||||
|
|
||||||
### Image Generation
|
### Image Generation
|
||||||
|
|
||||||
**POST** `/v1/images/generations`
|
**POST** `/v1/images/generations`
|
||||||
|
|
||||||
Executes an image generation request using an OpenAI-compatible schema with additional advanced_params.
|
Executes an image generation request using an OpenAI-compatible schema with additional advanced_params. Supports both streaming and non-streaming modes.
|
||||||
|
|
||||||
**Request body (example):**
|
**Request body (example):**
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"prompt": "a robot playing chess",
|
"prompt": "a robot playing chess",
|
||||||
"model": "flux-dev",
|
"model": "exolabs/FLUX.1-dev",
|
||||||
|
"n": 1,
|
||||||
|
"size": "1024x1024",
|
||||||
"stream": false,
|
"stream": false,
|
||||||
|
"response_format": "b64_json"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Request parameters:**
|
||||||
|
|
||||||
|
* `prompt`: string, required - Text description of the image
|
||||||
|
* `model`: string, required - Image model ID
|
||||||
|
* `n`: integer (default: 1) - Number of images to generate
|
||||||
|
* `size`: string (default: "auto") - Image dimensions. Supported sizes:
|
||||||
|
- `512x512`
|
||||||
|
- `768x768`
|
||||||
|
- `1024x768`
|
||||||
|
- `768x1024`
|
||||||
|
- `1024x1024`
|
||||||
|
- `1024x1536`
|
||||||
|
- `1536x1024`
|
||||||
|
- `1024x1365`
|
||||||
|
- `1365x1024`
|
||||||
|
* `stream`: boolean (default: false) - Enable streaming for partial images
|
||||||
|
* `partial_images`: integer (default: 0) - Number of partial images to stream during generation
|
||||||
|
* `response_format`: string (default: "b64_json") - Either `url` or `b64_json`
|
||||||
|
* `quality`: string (default: "medium") - Either `high`, `medium`, or `low`
|
||||||
|
* `output_format`: string (default: "png") - Either `png`, `jpeg`, or `webp`
|
||||||
|
* `advanced_params`: object (optional) - Advanced generation parameters
|
||||||
|
|
||||||
**Advanced Parameters (`advanced_params`):**
|
**Advanced Parameters (`advanced_params`):**
|
||||||
|
|
||||||
| Parameter | Type | Constraints | Description |
|
| Parameter | Type | Constraints | Description |
|
||||||
@@ -231,8 +526,54 @@ Executes an image generation request using an OpenAI-compatible schema with addi
|
|||||||
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
|
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
|
||||||
| `negative_prompt` | string | - | Text describing what to avoid in the image |
|
| `negative_prompt` | string | - | Text describing what to avoid in the image |
|
||||||
|
|
||||||
|
**Non-streaming response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"created": 1234567890,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"b64_json": "iVBORw0KGgoAAAANSUhEUgAA...",
|
||||||
|
"url": null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Streaming response format:**
|
||||||
|
When `stream=true` and `partial_images > 0`, returns Server-Sent Events:
|
||||||
|
|
||||||
|
```
|
||||||
|
data: {"type":"partial","image_index":0,"partial_index":1,"total_partials":5,"format":"png","data":{"b64_json":"..."}}
|
||||||
|
|
||||||
|
data: {"type":"final","image_index":0,"format":"png","data":{"b64_json":"..."}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Image Editing
|
||||||
|
|
||||||
|
**POST** `/v1/images/edits`
|
||||||
|
|
||||||
|
Executes an image editing request (img2img) using FLUX.1-Kontext-dev or similar models.
|
||||||
|
|
||||||
|
**Request (multipart/form-data):**
|
||||||
|
|
||||||
|
* `image`: file, required - Input image to edit
|
||||||
|
* `prompt`: string, required - Text description of desired changes
|
||||||
|
* `model`: string, required - Image editing model ID (e.g., `exolabs/FLUX.1-Kontext-dev`)
|
||||||
|
* `n`: integer (default: 1) - Number of edited images to generate
|
||||||
|
* `size`: string (optional) - Output image dimensions
|
||||||
|
* `response_format`: string (default: "b64_json") - Either `url` or `b64_json`
|
||||||
|
* `input_fidelity`: string (default: "low") - Either `low` or `high` - Controls how closely the output follows the input image
|
||||||
|
* `stream`: string (default: "false") - Enable streaming
|
||||||
|
* `partial_images`: string (default: "0") - Number of partial images to stream
|
||||||
|
* `quality`: string (default: "medium") - Either `high`, `medium`, or `low`
|
||||||
|
* `output_format`: string (default: "png") - Either `png`, `jpeg`, or `webp`
|
||||||
|
* `advanced_params`: string (optional) - JSON-encoded advanced parameters
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
OpenAI-compatible image generation response.
|
Same format as `/v1/images/generations`.
|
||||||
|
|
||||||
### Benchmarked Image Generation
|
### Benchmarked Image Generation
|
||||||
|
|
||||||
@@ -244,16 +585,15 @@ Same as `/v1/images/generations`, but also returns generation statistics.
|
|||||||
Same schema as `/v1/images/generations`.
|
Same schema as `/v1/images/generations`.
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
Image generation plus benchmarking metrics.
|
Image generation plus benchmarking metrics including:
|
||||||
|
|
||||||
### Image Editing
|
* `seconds_per_step` - Average time per denoising step
|
||||||
|
* `total_generation_time` - Total generation time
|
||||||
**POST** `/v1/images/edits`
|
* `num_inference_steps` - Number of inference steps used
|
||||||
|
* `num_images` - Number of images generated
|
||||||
Executes an image editing request using an OpenAI-compatible schema with additional advanced_params (same as `/v1/images/generations`).
|
* `image_width` - Output image width
|
||||||
|
* `image_height` - Output image height
|
||||||
**Response:**
|
* `peak_memory_usage` - Peak memory used during generation
|
||||||
Same format as `/v1/images/generations`.
|
|
||||||
|
|
||||||
### Benchmarked Image Editing
|
### Benchmarked Image Editing
|
||||||
|
|
||||||
@@ -267,37 +607,127 @@ Same schema as `/v1/images/edits`.
|
|||||||
**Response:**
|
**Response:**
|
||||||
Same format as `/bench/images/generations`, including `generation_stats`.
|
Same format as `/bench/images/generations`, including `generation_stats`.
|
||||||
|
|
||||||
## 6. Complete Endpoint Summary
|
### List Images
|
||||||
|
|
||||||
|
**GET** `/images`
|
||||||
|
|
||||||
|
List all stored images.
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Array of image metadata including URLs and expiration times.
|
||||||
|
|
||||||
|
### Get Image
|
||||||
|
|
||||||
|
**GET** `/images/{image_id}`
|
||||||
|
|
||||||
|
Retrieve a stored image by ID.
|
||||||
|
|
||||||
|
**Path parameters:**
|
||||||
|
|
||||||
|
* `image_id`: string, ID of the image
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
Image file with appropriate content type.
|
||||||
|
|
||||||
|
## 7. Complete Endpoint Summary
|
||||||
|
|
||||||
```
|
```
|
||||||
|
# General
|
||||||
GET /node_id
|
GET /node_id
|
||||||
GET /state
|
GET /state
|
||||||
GET /events
|
GET /events
|
||||||
|
|
||||||
|
# Instance Management
|
||||||
POST /instance
|
POST /instance
|
||||||
GET /instance/{instance_id}
|
GET /instance/{instance_id}
|
||||||
DELETE /instance/{instance_id}
|
DELETE /instance/{instance_id}
|
||||||
|
|
||||||
GET /instance/previews
|
GET /instance/previews
|
||||||
GET /instance/placement
|
GET /instance/placement
|
||||||
POST /place_instance
|
POST /place_instance
|
||||||
|
|
||||||
|
# Models
|
||||||
GET /models
|
GET /models
|
||||||
GET /v1/models
|
GET /v1/models
|
||||||
|
POST /models/add
|
||||||
|
DELETE /models/custom/{model_id}
|
||||||
|
GET /models/search
|
||||||
|
|
||||||
|
# Text Generation (OpenAI Chat Completions)
|
||||||
POST /v1/chat/completions
|
POST /v1/chat/completions
|
||||||
POST /bench/chat/completions
|
POST /bench/chat/completions
|
||||||
|
|
||||||
|
# Text Generation (Claude Messages API)
|
||||||
|
POST /v1/messages
|
||||||
|
|
||||||
|
# Text Generation (OpenAI Responses API)
|
||||||
|
POST /v1/responses
|
||||||
|
|
||||||
|
# Text Generation (Ollama API)
|
||||||
|
POST /ollama/api/chat
|
||||||
|
POST /ollama/api/api/chat
|
||||||
|
POST /ollama/api/v1/chat
|
||||||
|
POST /ollama/api/generate
|
||||||
|
GET /ollama/api/tags
|
||||||
|
GET /ollama/api/api/tags
|
||||||
|
GET /ollama/api/v1/tags
|
||||||
|
POST /ollama/api/show
|
||||||
|
GET /ollama/api/ps
|
||||||
|
GET /ollama/api/version
|
||||||
|
HEAD /ollama/
|
||||||
|
HEAD /ollama/api/version
|
||||||
|
|
||||||
|
# Command Control
|
||||||
POST /v1/cancel/{command_id}
|
POST /v1/cancel/{command_id}
|
||||||
|
|
||||||
|
# Image Generation
|
||||||
POST /v1/images/generations
|
POST /v1/images/generations
|
||||||
POST /bench/images/generations
|
POST /bench/images/generations
|
||||||
POST /v1/images/edits
|
POST /v1/images/edits
|
||||||
POST /bench/images/edits
|
POST /bench/images/edits
|
||||||
|
GET /images
|
||||||
|
GET /images/{image_id}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 7. Notes
|
## 8. Notes
|
||||||
|
|
||||||
* The `/v1/chat/completions` endpoint is compatible with the OpenAI Chat API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
|
### API Compatibility
|
||||||
* The `/v1/images/generations` and `/v1/images/edits` endpoints are compatible with the OpenAI Images API format.
|
|
||||||
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
|
EXO provides multiple API-compatible interfaces:
|
||||||
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
|
|
||||||
|
* **OpenAI Chat Completions API** - Compatible with OpenAI clients and tools
|
||||||
|
* **Claude Messages API** - Compatible with Anthropic's Claude API format
|
||||||
|
* **OpenAI Responses API** - Compatible with OpenAI's Responses API format
|
||||||
|
* **Ollama API** - Compatible with Ollama and tools like OpenWebUI
|
||||||
|
|
||||||
|
Existing OpenAI, Claude, or Ollama clients can be pointed to EXO by changing the base URL.
|
||||||
|
|
||||||
|
### Custom Models
|
||||||
|
|
||||||
|
You can add custom models from HuggingFace using the `/models/add` endpoint. Custom models are identified by the `is_custom` field in model list responses.
|
||||||
|
|
||||||
|
**Security:** Models requiring `trust_remote_code` must be explicitly enabled (default is false) for security. Only enable this if you trust the model's remote code.
|
||||||
|
|
||||||
|
### Usage Statistics
|
||||||
|
|
||||||
|
Chat completion responses include usage statistics with:
|
||||||
|
|
||||||
|
* `prompt_tokens` - Number of tokens in the prompt
|
||||||
|
* `completion_tokens` - Number of tokens generated
|
||||||
|
* `total_tokens` - Sum of prompt and completion tokens
|
||||||
|
|
||||||
|
### Request Cancellation
|
||||||
|
|
||||||
|
You can cancel active requests by:
|
||||||
|
|
||||||
|
1. Closing the HTTP connection (for streaming requests)
|
||||||
|
2. Calling `/v1/cancel/{command_id}` (for any request)
|
||||||
|
|
||||||
|
The server detects cancellation and stops processing immediately.
|
||||||
|
|
||||||
|
### Instance Placement
|
||||||
|
|
||||||
|
The instance placement endpoints allow you to plan and preview cluster allocations before creating instances. This helps optimize resource usage across nodes.
|
||||||
|
|
||||||
|
### Observability
|
||||||
|
|
||||||
|
The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
|
||||||
|
|||||||
@@ -28,6 +28,26 @@ There are currently 5 major systems:
|
|||||||
|
|
||||||
Implements a distributed algorithm for master election in unstable networking conditions
|
Implements a distributed algorithm for master election in unstable networking conditions
|
||||||
|
|
||||||
|
## API Layer
|
||||||
|
|
||||||
|
The API system uses multiple adapters to support multiple API formats, converting them to a single request / response type.
|
||||||
|
|
||||||
|
### Adapter Pattern
|
||||||
|
|
||||||
|
Adapters convert between external API formats and EXO's internal types:
|
||||||
|
|
||||||
|
```
|
||||||
|
Chat Completions → [adapter] → TextGenerationTaskParams → Application
|
||||||
|
Claude Messages → [adapter] → TextGenerationTaskParams → Application
|
||||||
|
Responses API → [adapter] → TextGenerationTaskParams → Application
|
||||||
|
Ollama API → [adapter] → TextGenerationTaskParams → Application
|
||||||
|
```
|
||||||
|
|
||||||
|
Each adapter implements two key functions:
|
||||||
|
1. **Request conversion**: Converts API-specific requests to `TextGenerationTaskParams`
|
||||||
|
2. **Response generation**: Converts internal `TokenChunk` streams back to API-specific formats (streaming and non-streaming)
|
||||||
|
|
||||||
|
|
||||||
## Topics
|
## Topics
|
||||||
|
|
||||||
There are currently 5 topics:
|
There are currently 5 topics:
|
||||||
|
|||||||
@@ -1,16 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
|
from exo_pyo3_bindings import (
|
||||||
|
Keypair,
|
||||||
|
NetworkingHandle,
|
||||||
|
NoPeersSubscribedToTopicError,
|
||||||
|
PyFromSwarm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sleep_on_multiple_items() -> None:
|
async def test_sleep_on_multiple_items() -> None:
|
||||||
print("PYTHON: starting handle")
|
print("PYTHON: starting handle")
|
||||||
h = NetworkingHandle(Keypair.generate_ed25519())
|
h = NetworkingHandle(Keypair.generate())
|
||||||
|
|
||||||
ct = asyncio.create_task(_await_cons(h))
|
rt = asyncio.create_task(_await_recv(h))
|
||||||
mt = asyncio.create_task(_await_msg(h))
|
|
||||||
|
|
||||||
# sleep for 4 ticks
|
# sleep for 4 ticks
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
@@ -22,13 +26,11 @@ async def test_sleep_on_multiple_items() -> None:
|
|||||||
print("caught it", e)
|
print("caught it", e)
|
||||||
|
|
||||||
|
|
||||||
async def _await_cons(h: NetworkingHandle):
|
async def _await_recv(h: NetworkingHandle):
|
||||||
while True:
|
while True:
|
||||||
c = await h.connection_update_recv()
|
event = await h.recv()
|
||||||
print(f"PYTHON: connection update: {c}")
|
match event:
|
||||||
|
case PyFromSwarm.Connection() as c:
|
||||||
|
print(f"PYTHON: connection update: {c}")
|
||||||
async def _await_msg(h: NetworkingHandle):
|
case PyFromSwarm.Message() as m:
|
||||||
while True:
|
print(f"PYTHON: message: {m}")
|
||||||
m = await h.gossipsub_recv()
|
|
||||||
print(f"PYTHON: message: {m}")
|
|
||||||
|
|||||||
@@ -19,9 +19,7 @@ def exo_shard_downloader(
|
|||||||
max_parallel_downloads: int = 8, offline: bool = False
|
max_parallel_downloads: int = 8, offline: bool = False
|
||||||
) -> ShardDownloader:
|
) -> ShardDownloader:
|
||||||
return SingletonShardDownloader(
|
return SingletonShardDownloader(
|
||||||
CachedShardDownloader(
|
ResumableShardDownloader(max_parallel_downloads, offline=offline)
|
||||||
ResumableShardDownloader(max_parallel_downloads, offline=offline)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -85,39 +83,6 @@ class SingletonShardDownloader(ShardDownloader):
|
|||||||
return await self.shard_downloader.get_shard_download_status_for_shard(shard)
|
return await self.shard_downloader.get_shard_download_status_for_shard(shard)
|
||||||
|
|
||||||
|
|
||||||
class CachedShardDownloader(ShardDownloader):
|
|
||||||
def __init__(self, shard_downloader: ShardDownloader):
|
|
||||||
self.shard_downloader = shard_downloader
|
|
||||||
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
|
|
||||||
|
|
||||||
def on_progress(
|
|
||||||
self,
|
|
||||||
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
|
|
||||||
) -> None:
|
|
||||||
self.shard_downloader.on_progress(callback)
|
|
||||||
|
|
||||||
async def ensure_shard(
|
|
||||||
self, shard: ShardMetadata, config_only: bool = False
|
|
||||||
) -> Path:
|
|
||||||
if (shard.model_card.model_id, shard) in self.cache:
|
|
||||||
return self.cache[(shard.model_card.model_id, shard)]
|
|
||||||
|
|
||||||
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
|
|
||||||
self.cache[(shard.model_card.model_id, shard)] = target_dir
|
|
||||||
return target_dir
|
|
||||||
|
|
||||||
async def get_shard_download_status(
|
|
||||||
self,
|
|
||||||
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
|
|
||||||
async for path, status in self.shard_downloader.get_shard_download_status():
|
|
||||||
yield path, status
|
|
||||||
|
|
||||||
async def get_shard_download_status_for_shard(
|
|
||||||
self, shard: ShardMetadata
|
|
||||||
) -> RepoDownloadProgress:
|
|
||||||
return await self.shard_downloader.get_shard_download_status_for_shard(shard)
|
|
||||||
|
|
||||||
|
|
||||||
class ResumableShardDownloader(ShardDownloader):
|
class ResumableShardDownloader(ShardDownloader):
|
||||||
def __init__(self, max_parallel_downloads: int = 8, offline: bool = False):
|
def __init__(self, max_parallel_downloads: int = 8, offline: bool = False):
|
||||||
self.max_parallel_downloads = max_parallel_downloads
|
self.max_parallel_downloads = max_parallel_downloads
|
||||||
|
|||||||
211
src/exo/download/tests/test_re_download.py
Normal file
211
src/exo/download/tests/test_re_download.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""Tests that re-downloading a previously deleted model completes successfully."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
from collections.abc import AsyncIterator, Awaitable
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from exo.download.coordinator import DownloadCoordinator
|
||||||
|
from exo.download.download_utils import RepoDownloadProgress
|
||||||
|
from exo.download.impl_shard_downloader import SingletonShardDownloader
|
||||||
|
from exo.download.shard_downloader import ShardDownloader
|
||||||
|
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||||
|
from exo.shared.types.commands import (
|
||||||
|
DeleteDownload,
|
||||||
|
ForwarderDownloadCommand,
|
||||||
|
StartDownload,
|
||||||
|
)
|
||||||
|
from exo.shared.types.common import NodeId, SystemId
|
||||||
|
from exo.shared.types.events import Event, NodeDownloadProgress
|
||||||
|
from exo.shared.types.memory import Memory
|
||||||
|
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||||
|
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||||
|
from exo.utils.channels import Receiver, Sender, channel
|
||||||
|
|
||||||
|
NODE_ID = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||||
|
MODEL_ID = ModelId("test-org/test-model")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_shard(model_id: ModelId = MODEL_ID) -> ShardMetadata:
|
||||||
|
return PipelineShardMetadata(
|
||||||
|
model_card=ModelCard(
|
||||||
|
model_id=model_id,
|
||||||
|
storage_size=Memory.from_mb(100),
|
||||||
|
n_layers=28,
|
||||||
|
hidden_size=1024,
|
||||||
|
supports_tensor=False,
|
||||||
|
tasks=[ModelTask.TextGeneration],
|
||||||
|
),
|
||||||
|
device_rank=0,
|
||||||
|
world_size=1,
|
||||||
|
start_layer=0,
|
||||||
|
end_layer=28,
|
||||||
|
n_layers=28,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeShardDownloader(ShardDownloader):
|
||||||
|
"""Fake downloader that simulates a successful download by firing the
|
||||||
|
progress callback with status='complete' when ensure_shard is called."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._progress_callbacks: list[
|
||||||
|
Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]
|
||||||
|
] = []
|
||||||
|
|
||||||
|
def on_progress(
|
||||||
|
self,
|
||||||
|
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
|
||||||
|
) -> None:
|
||||||
|
self._progress_callbacks.append(callback)
|
||||||
|
|
||||||
|
async def ensure_shard(
|
||||||
|
self,
|
||||||
|
shard: ShardMetadata,
|
||||||
|
config_only: bool = False, # noqa: ARG002
|
||||||
|
) -> Path:
|
||||||
|
# Simulate a completed download by firing the progress callback
|
||||||
|
progress = RepoDownloadProgress(
|
||||||
|
repo_id=str(shard.model_card.model_id),
|
||||||
|
repo_revision="main",
|
||||||
|
shard=shard,
|
||||||
|
completed_files=1,
|
||||||
|
total_files=1,
|
||||||
|
downloaded=Memory.from_mb(100),
|
||||||
|
downloaded_this_session=Memory.from_mb(100),
|
||||||
|
total=Memory.from_mb(100),
|
||||||
|
overall_speed=0,
|
||||||
|
overall_eta=timedelta(seconds=0),
|
||||||
|
status="complete",
|
||||||
|
)
|
||||||
|
for cb in self._progress_callbacks:
|
||||||
|
await cb(shard, progress)
|
||||||
|
return Path("/fake/models") / shard.model_card.model_id.normalize()
|
||||||
|
|
||||||
|
async def get_shard_download_status(
|
||||||
|
self,
|
||||||
|
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
|
||||||
|
if False: # noqa: SIM108 # empty async generator
|
||||||
|
yield (
|
||||||
|
Path(),
|
||||||
|
RepoDownloadProgress( # pyright: ignore[reportUnreachable]
|
||||||
|
repo_id="",
|
||||||
|
repo_revision="",
|
||||||
|
shard=_make_shard(),
|
||||||
|
completed_files=0,
|
||||||
|
total_files=0,
|
||||||
|
downloaded=Memory.from_bytes(0),
|
||||||
|
downloaded_this_session=Memory.from_bytes(0),
|
||||||
|
total=Memory.from_bytes(0),
|
||||||
|
overall_speed=0,
|
||||||
|
overall_eta=timedelta(seconds=0),
|
||||||
|
status="not_started",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_shard_download_status_for_shard(
|
||||||
|
self,
|
||||||
|
shard: ShardMetadata,
|
||||||
|
) -> RepoDownloadProgress:
|
||||||
|
return RepoDownloadProgress(
|
||||||
|
repo_id=str(shard.model_card.model_id),
|
||||||
|
repo_revision="main",
|
||||||
|
shard=shard,
|
||||||
|
completed_files=0,
|
||||||
|
total_files=1,
|
||||||
|
downloaded=Memory.from_bytes(0),
|
||||||
|
downloaded_this_session=Memory.from_bytes(0),
|
||||||
|
total=Memory.from_mb(100),
|
||||||
|
overall_speed=0,
|
||||||
|
overall_eta=timedelta(seconds=0),
|
||||||
|
status="not_started",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_re_download_after_delete_completes() -> None:
|
||||||
|
"""A model that was downloaded, deleted, and then re-downloaded should
|
||||||
|
reach DownloadCompleted status. This is an end-to-end test through
|
||||||
|
the DownloadCoordinator."""
|
||||||
|
cmd_send: Sender[ForwarderDownloadCommand]
|
||||||
|
cmd_send, cmd_recv = channel[ForwarderDownloadCommand]()
|
||||||
|
event_send, event_recv = channel[Event]()
|
||||||
|
|
||||||
|
fake_downloader = FakeShardDownloader()
|
||||||
|
wrapped_downloader = SingletonShardDownloader(fake_downloader)
|
||||||
|
coordinator = DownloadCoordinator(
|
||||||
|
node_id=NODE_ID,
|
||||||
|
shard_downloader=wrapped_downloader,
|
||||||
|
download_command_receiver=cmd_recv,
|
||||||
|
event_sender=event_send,
|
||||||
|
)
|
||||||
|
|
||||||
|
shard = _make_shard()
|
||||||
|
origin = SystemId("test")
|
||||||
|
|
||||||
|
with patch("exo.download.coordinator.delete_model", new_callable=AsyncMock):
|
||||||
|
# Run the coordinator in the background
|
||||||
|
coordinator_task = asyncio.create_task(coordinator.run())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Start first download
|
||||||
|
await cmd_send.send(
|
||||||
|
ForwarderDownloadCommand(
|
||||||
|
origin=origin,
|
||||||
|
command=StartDownload(target_node_id=NODE_ID, shard_metadata=shard),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for DownloadCompleted
|
||||||
|
first_completed = await _wait_for_download_completed(event_recv, MODEL_ID)
|
||||||
|
assert first_completed is not None, "First download should complete"
|
||||||
|
|
||||||
|
# 2. Delete the model
|
||||||
|
await cmd_send.send(
|
||||||
|
ForwarderDownloadCommand(
|
||||||
|
origin=origin,
|
||||||
|
command=DeleteDownload(target_node_id=NODE_ID, model_id=MODEL_ID),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Give the coordinator time to process the delete
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
# 3. Re-download the same model
|
||||||
|
await cmd_send.send(
|
||||||
|
ForwarderDownloadCommand(
|
||||||
|
origin=origin,
|
||||||
|
command=StartDownload(target_node_id=NODE_ID, shard_metadata=shard),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for second DownloadCompleted — this is the bug: it never arrives
|
||||||
|
second_completed = await _wait_for_download_completed(event_recv, MODEL_ID)
|
||||||
|
assert second_completed is not None, (
|
||||||
|
"Re-download after deletion should complete"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
coordinator.shutdown()
|
||||||
|
coordinator_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await coordinator_task
|
||||||
|
|
||||||
|
|
||||||
|
async def _wait_for_download_completed(
|
||||||
|
event_recv: Receiver[Event], model_id: ModelId, timeout: float = 2.0
|
||||||
|
) -> DownloadCompleted | None:
|
||||||
|
"""Drain events until we see a DownloadCompleted for the given model, or timeout."""
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(timeout):
|
||||||
|
while True:
|
||||||
|
event = await event_recv.receive()
|
||||||
|
if (
|
||||||
|
isinstance(event, NodeDownloadProgress)
|
||||||
|
and isinstance(event.download_progress, DownloadCompleted)
|
||||||
|
and event.download_progress.shard_metadata.model_card.model_id
|
||||||
|
== model_id
|
||||||
|
):
|
||||||
|
return event.download_progress
|
||||||
|
except TimeoutError:
|
||||||
|
return None
|
||||||
@@ -26,7 +26,11 @@ from exo.shared.types.chunks import (
|
|||||||
ToolCallChunk,
|
ToolCallChunk,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import CommandId
|
from exo.shared.types.common import CommandId
|
||||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
from exo.shared.types.text_generation import (
|
||||||
|
InputMessage,
|
||||||
|
TextGenerationTaskParams,
|
||||||
|
resolve_reasoning_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def chat_request_to_text_generation(
|
def chat_request_to_text_generation(
|
||||||
@@ -75,6 +79,10 @@ def chat_request_to_text_generation(
|
|||||||
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
|
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
|
||||||
chat_template_messages.append(dumped)
|
chat_template_messages.append(dumped)
|
||||||
|
|
||||||
|
resolved_effort, resolved_thinking = resolve_reasoning_params(
|
||||||
|
request.reasoning_effort, request.enable_thinking
|
||||||
|
)
|
||||||
|
|
||||||
return TextGenerationTaskParams(
|
return TextGenerationTaskParams(
|
||||||
model=request.model,
|
model=request.model,
|
||||||
input=input_messages
|
input=input_messages
|
||||||
@@ -89,12 +97,15 @@ def chat_request_to_text_generation(
|
|||||||
seed=request.seed,
|
seed=request.seed,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
tools=request.tools,
|
tools=request.tools,
|
||||||
enable_thinking=request.enable_thinking,
|
reasoning_effort=resolved_effort,
|
||||||
|
enable_thinking=resolved_thinking,
|
||||||
chat_template_messages=chat_template_messages
|
chat_template_messages=chat_template_messages
|
||||||
if chat_template_messages
|
if chat_template_messages
|
||||||
else None,
|
else None,
|
||||||
logprobs=request.logprobs or False,
|
logprobs=request.logprobs or False,
|
||||||
top_logprobs=request.top_logprobs,
|
top_logprobs=request.top_logprobs,
|
||||||
|
repetition_penalty=request.repetition_penalty,
|
||||||
|
repetition_context_size=request.repetition_context_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,11 @@ from exo.shared.types.openai_responses import (
|
|||||||
ResponseTextDoneEvent,
|
ResponseTextDoneEvent,
|
||||||
ResponseUsage,
|
ResponseUsage,
|
||||||
)
|
)
|
||||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
from exo.shared.types.text_generation import (
|
||||||
|
InputMessage,
|
||||||
|
TextGenerationTaskParams,
|
||||||
|
resolve_reasoning_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _format_sse(event: ResponsesStreamEvent) -> str:
|
def _format_sse(event: ResponsesStreamEvent) -> str:
|
||||||
@@ -119,6 +123,11 @@ def responses_request_to_text_generation(
|
|||||||
)
|
)
|
||||||
built_chat_template = chat_template_messages if chat_template_messages else None
|
built_chat_template = chat_template_messages if chat_template_messages else None
|
||||||
|
|
||||||
|
effort_from_reasoning = request.reasoning.effort if request.reasoning else None
|
||||||
|
resolved_effort, resolved_thinking = resolve_reasoning_params(
|
||||||
|
effort_from_reasoning, request.enable_thinking
|
||||||
|
)
|
||||||
|
|
||||||
return TextGenerationTaskParams(
|
return TextGenerationTaskParams(
|
||||||
model=request.model,
|
model=request.model,
|
||||||
input=input_value,
|
input=input_value,
|
||||||
@@ -132,6 +141,8 @@ def responses_request_to_text_generation(
|
|||||||
stop=request.stop,
|
stop=request.stop,
|
||||||
seed=request.seed,
|
seed=request.seed,
|
||||||
chat_template_messages=built_chat_template or request.chat_template_messages,
|
chat_template_messages=built_chat_template or request.chat_template_messages,
|
||||||
|
reasoning_effort=resolved_effort,
|
||||||
|
enable_thinking=resolved_thinking,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import contextlib
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
|
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -775,7 +775,7 @@ class API:
|
|||||||
return resolved_model
|
return resolved_model
|
||||||
|
|
||||||
def stream_events(self) -> StreamingResponse:
|
def stream_events(self) -> StreamingResponse:
|
||||||
def _generate_json_array(events: Iterator[Event]) -> Iterator[str]:
|
def _generate_json_array(events: Iterable[Event]) -> Iterable[str]:
|
||||||
yield "["
|
yield "["
|
||||||
first = True
|
first = True
|
||||||
for event in events:
|
for event in events:
|
||||||
@@ -1586,26 +1586,42 @@ class API:
|
|||||||
async def search_models(
|
async def search_models(
|
||||||
self, query: str = "", limit: int = 20
|
self, query: str = "", limit: int = 20
|
||||||
) -> list[HuggingFaceSearchResult]:
|
) -> list[HuggingFaceSearchResult]:
|
||||||
"""Search HuggingFace Hub for mlx-community models."""
|
"""Search HuggingFace Hub — tries mlx-community first, falls back to all of HuggingFace."""
|
||||||
from huggingface_hub import list_models
|
from huggingface_hub import ModelInfo, list_models
|
||||||
|
|
||||||
results = list_models(
|
def _to_results(models: Iterable[ModelInfo]) -> list[HuggingFaceSearchResult]:
|
||||||
search=query or None,
|
return [
|
||||||
author="mlx-community",
|
HuggingFaceSearchResult(
|
||||||
sort="downloads",
|
id=m.id,
|
||||||
limit=limit,
|
author=m.author or "",
|
||||||
)
|
downloads=m.downloads or 0,
|
||||||
return [
|
likes=m.likes or 0,
|
||||||
HuggingFaceSearchResult(
|
last_modified=str(m.last_modified or ""),
|
||||||
id=m.id,
|
tags=list(m.tags or []),
|
||||||
author=m.author or "",
|
)
|
||||||
downloads=m.downloads or 0,
|
for m in models
|
||||||
likes=m.likes or 0,
|
]
|
||||||
last_modified=str(m.last_modified or ""),
|
|
||||||
tags=list(m.tags or []),
|
# Search mlx-community first
|
||||||
|
mlx_results = _to_results(
|
||||||
|
list_models(
|
||||||
|
search=query or None,
|
||||||
|
author="mlx-community",
|
||||||
|
sort="downloads",
|
||||||
|
limit=limit,
|
||||||
)
|
)
|
||||||
for m in results
|
)
|
||||||
]
|
if mlx_results:
|
||||||
|
return mlx_results
|
||||||
|
|
||||||
|
# Fall back to searching all of HuggingFace
|
||||||
|
return _to_results(
|
||||||
|
list_models(
|
||||||
|
search=query or None,
|
||||||
|
sort="downloads",
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
shutdown_ev = anyio.Event()
|
shutdown_ev = anyio.Event()
|
||||||
|
|||||||
@@ -328,17 +328,22 @@ class Master:
|
|||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
case TaskFinished():
|
else:
|
||||||
generated_events.append(
|
logger.warning(
|
||||||
TaskDeleted(
|
f"Nonexistent command {command.cancelled_command_id} cancelled"
|
||||||
task_id=self.command_task_mapping[
|
|
||||||
command.finished_command_id
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
)
|
case TaskFinished():
|
||||||
self.command_task_mapping.pop(
|
if (
|
||||||
command.finished_command_id, None
|
task_id := self.command_task_mapping.pop(
|
||||||
)
|
command.finished_command_id, None
|
||||||
|
)
|
||||||
|
) is not None:
|
||||||
|
generated_events.append(TaskDeleted(task_id=task_id))
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Finished command {command.finished_command_id} finished"
|
||||||
|
)
|
||||||
|
|
||||||
case RequestEventLog():
|
case RequestEventLog():
|
||||||
# We should just be able to send everything, since other buffers will ignore old messages
|
# We should just be able to send everything, since other buffers will ignore old messages
|
||||||
# rate limit to 1000 at a time
|
# rate limit to 1000 at a time
|
||||||
|
|||||||
@@ -258,6 +258,6 @@ def get_node_id_keypair(
|
|||||||
|
|
||||||
# if no valid credentials, create new ones and persist
|
# if no valid credentials, create new ones and persist
|
||||||
with open(path, "w+b") as f:
|
with open(path, "w+b") as f:
|
||||||
keypair = Keypair.generate_ed25519()
|
keypair = Keypair.generate()
|
||||||
f.write(keypair.to_bytes())
|
f.write(keypair.to_bytes())
|
||||||
return keypair
|
return keypair
|
||||||
|
|||||||
30
src/exo/shared/tests/test_resolve_reasoning_params.py
Normal file
30
src/exo/shared/tests/test_resolve_reasoning_params.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from exo.shared.types.text_generation import resolve_reasoning_params
|
||||||
|
|
||||||
|
|
||||||
|
def test_both_none_returns_none_none() -> None:
|
||||||
|
assert resolve_reasoning_params(None, None) == (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_both_set_passes_through_unchanged() -> None:
|
||||||
|
assert resolve_reasoning_params("high", True) == ("high", True)
|
||||||
|
assert resolve_reasoning_params("none", True) == ("none", True)
|
||||||
|
assert resolve_reasoning_params("low", False) == ("low", False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enable_thinking_true_derives_medium() -> None:
|
||||||
|
assert resolve_reasoning_params(None, True) == ("medium", True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enable_thinking_false_derives_none() -> None:
|
||||||
|
assert resolve_reasoning_params(None, False) == ("none", False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reasoning_effort_none_derives_thinking_false() -> None:
|
||||||
|
assert resolve_reasoning_params("none", None) == ("none", False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("effort", ["minimal", "low", "medium", "high", "xhigh"])
|
||||||
|
def test_non_none_effort_derives_thinking_true(effort: str) -> None:
|
||||||
|
assert resolve_reasoning_params(effort, None) == (effort, True) # pyright: ignore[reportArgumentType]
|
||||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||||
from exo.shared.types.common import CommandId, NodeId
|
from exo.shared.types.common import CommandId, NodeId
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
|
from exo.shared.types.text_generation import ReasoningEffort
|
||||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel
|
from exo.utils.pydantic_ext import CamelCaseModel
|
||||||
@@ -198,7 +199,10 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
top_p: float | None = None
|
top_p: float | None = None
|
||||||
top_k: int | None = None
|
top_k: int | None = None
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
reasoning_effort: ReasoningEffort | None = None
|
||||||
enable_thinking: bool | None = None
|
enable_thinking: bool | None = None
|
||||||
|
repetition_penalty: float | None = None
|
||||||
|
repetition_context_size: int | None = None
|
||||||
tool_choice: str | dict[str, Any] | None = None
|
tool_choice: str | dict[str, Any] | None = None
|
||||||
parallel_tool_calls: bool | None = None
|
parallel_tool_calls: bool | None = None
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from typing import Any, Literal
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from exo.shared.types.common import ModelId
|
from exo.shared.types.common import ModelId
|
||||||
|
from exo.shared.types.text_generation import ReasoningEffort
|
||||||
|
|
||||||
# Type aliases
|
# Type aliases
|
||||||
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
|
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
|
||||||
@@ -71,6 +72,13 @@ ResponseInputItem = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Reasoning(BaseModel, frozen=True):
|
||||||
|
"""Reasoning configuration for OpenAI Responses API."""
|
||||||
|
|
||||||
|
effort: ReasoningEffort | None = None
|
||||||
|
summary: Literal["auto", "concise", "detailed"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ResponsesRequest(BaseModel, frozen=True):
|
class ResponsesRequest(BaseModel, frozen=True):
|
||||||
"""Request body for OpenAI Responses API.
|
"""Request body for OpenAI Responses API.
|
||||||
|
|
||||||
@@ -89,8 +97,15 @@ class ResponsesRequest(BaseModel, frozen=True):
|
|||||||
stream: bool = False
|
stream: bool = False
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
metadata: dict[str, str] | None = None
|
metadata: dict[str, str] | None = None
|
||||||
|
reasoning: Reasoning | None = None
|
||||||
|
|
||||||
# --- exo extensions (not in OpenAI Responses API spec) ---
|
# --- exo extensions (not in OpenAI Responses API spec) ---
|
||||||
|
enable_thinking: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[exo extension] Boolean thinking toggle. Not part of the OpenAI Responses API.",
|
||||||
|
json_schema_extra={"x-exo-extension": True},
|
||||||
|
)
|
||||||
|
|
||||||
top_k: int | None = Field(
|
top_k: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="[exo extension] Top-k sampling parameter. Not part of the OpenAI Responses API.",
|
description="[exo extension] Top-k sampling parameter. Not part of the OpenAI Responses API.",
|
||||||
|
|||||||
@@ -11,6 +11,29 @@ from pydantic import BaseModel
|
|||||||
from exo.shared.types.common import ModelId
|
from exo.shared.types.common import ModelId
|
||||||
|
|
||||||
MessageRole = Literal["user", "assistant", "system", "developer"]
|
MessageRole = Literal["user", "assistant", "system", "developer"]
|
||||||
|
ReasoningEffort = Literal["none", "minimal", "low", "medium", "high", "xhigh"]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_reasoning_params(
|
||||||
|
reasoning_effort: ReasoningEffort | None,
|
||||||
|
enable_thinking: bool | None,
|
||||||
|
) -> tuple[ReasoningEffort | None, bool | None]:
|
||||||
|
"""
|
||||||
|
enable_thinking=True -> reasoning_effort="medium"
|
||||||
|
enable_thinking=False -> reasoning_effort="none"
|
||||||
|
reasoning_effort="none" -> enable_thinking=False
|
||||||
|
reasoning_effort=<anything else> -> enable_thinking=True
|
||||||
|
"""
|
||||||
|
resolved_effort: ReasoningEffort | None = reasoning_effort
|
||||||
|
resolved_thinking: bool | None = enable_thinking
|
||||||
|
|
||||||
|
if reasoning_effort is None and enable_thinking is not None:
|
||||||
|
resolved_effort = "medium" if enable_thinking else "none"
|
||||||
|
|
||||||
|
if enable_thinking is None and reasoning_effort is not None:
|
||||||
|
resolved_thinking = reasoning_effort != "none"
|
||||||
|
|
||||||
|
return resolved_effort, resolved_thinking
|
||||||
|
|
||||||
|
|
||||||
class InputMessage(BaseModel, frozen=True):
|
class InputMessage(BaseModel, frozen=True):
|
||||||
@@ -40,6 +63,9 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
|
|||||||
stop: str | list[str] | None = None
|
stop: str | list[str] | None = None
|
||||||
seed: int | None = None
|
seed: int | None = None
|
||||||
chat_template_messages: list[dict[str, Any]] | None = None
|
chat_template_messages: list[dict[str, Any]] | None = None
|
||||||
|
reasoning_effort: ReasoningEffort | None = None
|
||||||
enable_thinking: bool | None = None
|
enable_thinking: bool | None = None
|
||||||
logprobs: bool = False
|
logprobs: bool = False
|
||||||
top_logprobs: int | None = None
|
top_logprobs: int | None = None
|
||||||
|
repetition_penalty: float | None = None
|
||||||
|
repetition_context_size: int | None = None
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from mlx_lm.generate import (
|
|||||||
stream_generate,
|
stream_generate,
|
||||||
)
|
)
|
||||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||||
from mlx_lm.sample_utils import make_sampler
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
|
||||||
from exo.shared.types.api import (
|
from exo.shared.types.api import (
|
||||||
@@ -437,6 +437,7 @@ def mlx_generate(
|
|||||||
group: mx.distributed.Group | None,
|
group: mx.distributed.Group | None,
|
||||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||||
distributed_prompt_progress_callback: Callable[[], None] | None = None,
|
distributed_prompt_progress_callback: Callable[[], None] | None = None,
|
||||||
|
on_generation_token: Callable[[], None] | None = None,
|
||||||
) -> Generator[GenerationResponse]:
|
) -> Generator[GenerationResponse]:
|
||||||
# Ensure that generation stats only contains peak memory for this generation
|
# Ensure that generation stats only contains peak memory for this generation
|
||||||
mx.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
@@ -469,11 +470,16 @@ def mlx_generate(
|
|||||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||||
)
|
)
|
||||||
|
|
||||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = (
|
||||||
|
make_logits_processors(
|
||||||
|
repetition_penalty=task.repetition_penalty,
|
||||||
|
repetition_context_size=task.repetition_context_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
if is_bench:
|
if is_bench:
|
||||||
# Only sample length eos tokens
|
# Only sample length eos tokens
|
||||||
eos_ids = eos_ids_from_tokenizer(tokenizer)
|
eos_ids = eos_ids_from_tokenizer(tokenizer)
|
||||||
logits_processors = [ban_token_ids(eos_ids)]
|
logits_processors = [ban_token_ids(eos_ids)] + logits_processors
|
||||||
|
|
||||||
sampler = make_sampler(
|
sampler = make_sampler(
|
||||||
temp=task.temperature if task.temperature is not None else 0.7,
|
temp=task.temperature if task.temperature is not None else 0.7,
|
||||||
@@ -644,6 +650,9 @@ def mlx_generate(
|
|||||||
full_prompt_tokens, caches, cache_snapshots
|
full_prompt_tokens, caches, cache_snapshots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if on_generation_token is not None:
|
||||||
|
on_generation_token()
|
||||||
|
|
||||||
yield GenerationResponse(
|
yield GenerationResponse(
|
||||||
text=text,
|
text=text,
|
||||||
token=out.token,
|
token=out.token,
|
||||||
|
|||||||
@@ -554,6 +554,8 @@ def apply_chat_template(
|
|||||||
# Jinja ignores unknown variables, so passing both is safe.
|
# Jinja ignores unknown variables, so passing both is safe.
|
||||||
extra_kwargs["enable_thinking"] = task_params.enable_thinking
|
extra_kwargs["enable_thinking"] = task_params.enable_thinking
|
||||||
extra_kwargs["thinking"] = task_params.enable_thinking
|
extra_kwargs["thinking"] = task_params.enable_thinking
|
||||||
|
if task_params.reasoning_effort is not None:
|
||||||
|
extra_kwargs["reasoning_effort"] = task_params.reasoning_effort
|
||||||
|
|
||||||
patched_template: str | None = None
|
patched_template: str | None = None
|
||||||
if task_params.tools:
|
if task_params.tools:
|
||||||
|
|||||||
@@ -297,10 +297,10 @@ def _pending_tasks(
|
|||||||
# the task status _should_ be set to completed by the LAST runner
|
# the task status _should_ be set to completed by the LAST runner
|
||||||
# it is currently set by the first
|
# it is currently set by the first
|
||||||
# this is definitely a hack
|
# this is definitely a hack
|
||||||
if task.task_id in runner.completed:
|
if task.task_id in runner.completed or task.task_id in runner.in_progress:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(runner.status, RunnerReady) and all(
|
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -32,11 +32,19 @@ def entrypoint(
|
|||||||
# Import main after setting global logger - this lets us just import logger from this module
|
# Import main after setting global logger - this lets us just import logger from this module
|
||||||
try:
|
try:
|
||||||
if bound_instance.is_image_model:
|
if bound_instance.is_image_model:
|
||||||
from exo.worker.runner.image_models.runner import main
|
from exo.worker.runner.image_models.runner import Runner as ImageRunner
|
||||||
else:
|
|
||||||
from exo.worker.runner.llm_inference.runner import main
|
|
||||||
|
|
||||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
runner = ImageRunner(
|
||||||
|
bound_instance, event_sender, task_receiver, cancel_receiver
|
||||||
|
)
|
||||||
|
runner.main()
|
||||||
|
else:
|
||||||
|
from exo.worker.runner.llm_inference.runner import Runner
|
||||||
|
|
||||||
|
runner = Runner(
|
||||||
|
bound_instance, event_sender, task_receiver, cancel_receiver
|
||||||
|
)
|
||||||
|
runner.main()
|
||||||
|
|
||||||
except ClosedResourceError:
|
except ClosedResourceError:
|
||||||
logger.warning("Runner communication closed unexpectedly")
|
logger.warning("Runner communication closed unexpectedly")
|
||||||
|
|||||||
@@ -182,272 +182,266 @@ def _send_image_chunk(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(
|
class Runner:
|
||||||
bound_instance: BoundInstance,
|
def __init__(
|
||||||
event_sender: MpSender[Event],
|
self,
|
||||||
task_receiver: MpReceiver[Task],
|
bound_instance: BoundInstance,
|
||||||
cancel_receiver: MpReceiver[TaskId],
|
event_sender: MpSender[Event],
|
||||||
):
|
task_receiver: MpReceiver[Task],
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
cancel_receiver: MpReceiver[TaskId],
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
):
|
||||||
|
self.event_sender = event_sender
|
||||||
|
self.task_receiver = task_receiver
|
||||||
|
self.cancel_receiver = cancel_receiver
|
||||||
|
self.bound_instance = bound_instance
|
||||||
|
|
||||||
instance, runner_id, shard_metadata = (
|
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||||
bound_instance.instance,
|
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||||
bound_instance.bound_runner_id,
|
|
||||||
bound_instance.bound_shard,
|
|
||||||
)
|
|
||||||
device_rank = shard_metadata.device_rank
|
|
||||||
logger.info("hello from the runner")
|
|
||||||
if getattr(shard_metadata, "immediate_exception", False):
|
|
||||||
raise Exception("Fake exception - runner failed to spin up.")
|
|
||||||
if timeout := getattr(shard_metadata, "should_timeout", 0):
|
|
||||||
time.sleep(timeout)
|
|
||||||
|
|
||||||
setup_start_time = time.time()
|
self.instance, self.runner_id, self.shard_metadata = (
|
||||||
cancelled_tasks = set[TaskId]()
|
bound_instance.instance,
|
||||||
|
bound_instance.bound_runner_id,
|
||||||
|
bound_instance.bound_shard,
|
||||||
|
)
|
||||||
|
self.device_rank = self.shard_metadata.device_rank
|
||||||
|
|
||||||
image_model: DistributedImageModel | None = None
|
logger.info("hello from the runner")
|
||||||
group = None
|
if getattr(self.shard_metadata, "immediate_exception", False):
|
||||||
|
raise Exception("Fake exception - runner failed to spin up.")
|
||||||
|
if timeout := getattr(self.shard_metadata, "should_timeout", 0):
|
||||||
|
time.sleep(timeout)
|
||||||
|
|
||||||
current_status: RunnerStatus = RunnerIdle()
|
self.setup_start_time = time.time()
|
||||||
logger.info("runner created")
|
self.cancelled_tasks = set[TaskId]()
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
self.image_model: DistributedImageModel | None = None
|
||||||
)
|
self.group = None
|
||||||
seen = set[TaskId]()
|
|
||||||
with task_receiver as tasks:
|
self.current_status: RunnerStatus = RunnerIdle()
|
||||||
for task in tasks:
|
logger.info("runner created")
|
||||||
if task.task_id in seen:
|
self.update_status(RunnerIdle())
|
||||||
logger.warning("repeat task - potential error")
|
self.seen = set[TaskId]()
|
||||||
seen.add(task.task_id)
|
|
||||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
def update_status(self, status: RunnerStatus):
|
||||||
event_sender.send(
|
self.current_status = status
|
||||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
self.event_sender.send(
|
||||||
|
RunnerStatusUpdated(
|
||||||
|
runner_id=self.runner_id, runner_status=self.current_status
|
||||||
)
|
)
|
||||||
match task:
|
)
|
||||||
case ConnectToGroup() if isinstance(
|
|
||||||
current_status, (RunnerIdle, RunnerFailed)
|
|
||||||
):
|
|
||||||
logger.info("runner connecting")
|
|
||||||
current_status = RunnerConnecting()
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
group = initialize_mlx(bound_instance)
|
|
||||||
|
|
||||||
logger.info("runner connected")
|
def send_task_status(self, task: Task, status: TaskStatus):
|
||||||
current_status = RunnerConnected()
|
self.event_sender.send(
|
||||||
|
TaskStatusUpdated(task_id=task.task_id, task_status=status)
|
||||||
|
)
|
||||||
|
|
||||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
def acknowledge_task(self, task: Task):
|
||||||
case LoadModel() if (
|
self.event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||||
isinstance(current_status, RunnerConnected) and group is not None
|
|
||||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
|
||||||
current_status = RunnerLoading()
|
|
||||||
logger.info("runner loading")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
assert (
|
def main(self):
|
||||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
with self.task_receiver as tasks:
|
||||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
for task in tasks:
|
||||||
), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
|
if task.task_id in self.seen:
|
||||||
|
logger.warning("repeat task - potential error")
|
||||||
image_model = initialize_image_model(bound_instance)
|
self.seen.add(task.task_id)
|
||||||
current_status = RunnerLoaded()
|
self.cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||||
logger.info("runner loaded")
|
self.send_task_status(task, TaskStatus.Running)
|
||||||
|
self.handle_task(task)
|
||||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
was_cancelled = (task.task_id in self.cancelled_tasks) or (
|
||||||
current_status = RunnerWarmingUp()
|
TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
|
||||||
logger.info("runner warming up")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
logger.info(f"warming up inference for instance: {instance}")
|
|
||||||
|
|
||||||
assert image_model
|
|
||||||
image = warmup_image_generator(model=image_model)
|
|
||||||
if image is not None:
|
|
||||||
logger.info(f"warmed up by generating {image.size} image")
|
|
||||||
else:
|
|
||||||
logger.info("warmup completed (non-primary node)")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
|
||||||
)
|
|
||||||
|
|
||||||
current_status = RunnerReady()
|
|
||||||
logger.info("runner ready")
|
|
||||||
|
|
||||||
case ImageGeneration(
|
|
||||||
task_params=task_params, command_id=command_id
|
|
||||||
) if isinstance(current_status, RunnerReady):
|
|
||||||
assert image_model
|
|
||||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
|
||||||
current_status = RunnerRunning()
|
|
||||||
logger.info("runner running")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_index = 0
|
|
||||||
for response in generate_image(
|
|
||||||
model=image_model, task=task_params
|
|
||||||
):
|
|
||||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
|
||||||
|
|
||||||
if is_primary_output:
|
|
||||||
match response:
|
|
||||||
case PartialImageResponse():
|
|
||||||
logger.info(
|
|
||||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
|
||||||
)
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
case ImageGenerationResponse():
|
|
||||||
logger.info("sending final ImageChunk")
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
image_index += 1
|
|
||||||
# can we make this more explicit?
|
|
||||||
except Exception as e:
|
|
||||||
if _is_primary_output_node(shard_metadata):
|
|
||||||
event_sender.send(
|
|
||||||
ChunkGenerated(
|
|
||||||
command_id=command_id,
|
|
||||||
chunk=ErrorChunk(
|
|
||||||
model=shard_metadata.model_card.model_id,
|
|
||||||
finish_reason="error",
|
|
||||||
error_message=str(e),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
_send_traces_if_enabled(event_sender, task.task_id, device_rank)
|
|
||||||
|
|
||||||
current_status = RunnerReady()
|
|
||||||
logger.info("runner ready")
|
|
||||||
|
|
||||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
|
||||||
isinstance(current_status, RunnerReady)
|
|
||||||
):
|
|
||||||
assert image_model
|
|
||||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
|
||||||
current_status = RunnerRunning()
|
|
||||||
logger.info("runner running")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_index = 0
|
|
||||||
for response in generate_image(
|
|
||||||
model=image_model, task=task_params
|
|
||||||
):
|
|
||||||
if _is_primary_output_node(shard_metadata):
|
|
||||||
match response:
|
|
||||||
case PartialImageResponse():
|
|
||||||
logger.info(
|
|
||||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
|
||||||
)
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
case ImageGenerationResponse():
|
|
||||||
logger.info("sending final ImageChunk")
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
image_index += 1
|
|
||||||
except Exception as e:
|
|
||||||
if _is_primary_output_node(shard_metadata):
|
|
||||||
event_sender.send(
|
|
||||||
ChunkGenerated(
|
|
||||||
command_id=command_id,
|
|
||||||
chunk=ErrorChunk(
|
|
||||||
model=shard_metadata.model_card.model_id,
|
|
||||||
finish_reason="error",
|
|
||||||
error_message=str(e),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
_send_traces_if_enabled(event_sender, task.task_id, device_rank)
|
|
||||||
|
|
||||||
current_status = RunnerReady()
|
|
||||||
logger.info("runner ready")
|
|
||||||
|
|
||||||
case Shutdown():
|
|
||||||
current_status = RunnerShuttingDown()
|
|
||||||
logger.info("runner shutting down")
|
|
||||||
if not TYPE_CHECKING:
|
|
||||||
del image_model, group
|
|
||||||
mx.clear_cache()
|
|
||||||
import gc
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
current_status = RunnerShutdown()
|
|
||||||
case _:
|
|
||||||
raise ValueError(
|
|
||||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
|
||||||
)
|
|
||||||
was_cancelled = (task.task_id in cancelled_tasks) or (
|
|
||||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
|
||||||
)
|
|
||||||
if not was_cancelled:
|
|
||||||
event_sender.send(
|
|
||||||
TaskStatusUpdated(
|
|
||||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
event_sender.send(
|
if not was_cancelled:
|
||||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
self.send_task_status(task, TaskStatus.Complete)
|
||||||
)
|
self.update_status(self.current_status)
|
||||||
|
|
||||||
if isinstance(current_status, RunnerShutdown):
|
if isinstance(self.current_status, RunnerShutdown):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def handle_task(self, task: Task):
|
||||||
|
match task:
|
||||||
|
case ConnectToGroup() if isinstance(
|
||||||
|
self.current_status, (RunnerIdle, RunnerFailed)
|
||||||
|
):
|
||||||
|
logger.info("runner connecting")
|
||||||
|
self.update_status(RunnerConnecting())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
self.group = initialize_mlx(self.bound_instance)
|
||||||
|
|
||||||
|
logger.info("runner connected")
|
||||||
|
self.current_status = RunnerConnected()
|
||||||
|
|
||||||
|
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||||
|
case LoadModel() if (
|
||||||
|
isinstance(self.current_status, RunnerConnected)
|
||||||
|
and self.group is not None
|
||||||
|
) or (isinstance(self.current_status, RunnerIdle) and self.group is None):
|
||||||
|
logger.info("runner loading")
|
||||||
|
self.update_status(RunnerLoading())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ModelTask.TextToImage in self.shard_metadata.model_card.tasks
|
||||||
|
or ModelTask.ImageToImage in self.shard_metadata.model_card.tasks
|
||||||
|
), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}"
|
||||||
|
|
||||||
|
self.image_model = initialize_image_model(self.bound_instance)
|
||||||
|
self.current_status = RunnerLoaded()
|
||||||
|
logger.info("runner loaded")
|
||||||
|
|
||||||
|
case StartWarmup() if isinstance(self.current_status, RunnerLoaded):
|
||||||
|
logger.info("runner warming up")
|
||||||
|
self.update_status(RunnerWarmingUp())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
logger.info(f"warming up inference for instance: {self.instance}")
|
||||||
|
|
||||||
|
assert self.image_model
|
||||||
|
image = warmup_image_generator(model=self.image_model)
|
||||||
|
if image is not None:
|
||||||
|
logger.info(f"warmed up by generating {image.size} image")
|
||||||
|
else:
|
||||||
|
logger.info("warmup completed (non-primary node)")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"runner initialized in {time.time() - self.setup_start_time} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_status = RunnerReady()
|
||||||
|
logger.info("runner ready")
|
||||||
|
|
||||||
|
case ImageGeneration(task_params=task_params, command_id=command_id) if (
|
||||||
|
isinstance(self.current_status, RunnerReady)
|
||||||
|
):
|
||||||
|
assert self.image_model
|
||||||
|
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||||
|
logger.info("runner running")
|
||||||
|
self.update_status(RunnerRunning())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_index = 0
|
||||||
|
for response in generate_image(
|
||||||
|
model=self.image_model, task=task_params
|
||||||
|
):
|
||||||
|
is_primary_output = _is_primary_output_node(self.shard_metadata)
|
||||||
|
|
||||||
|
if is_primary_output:
|
||||||
|
match response:
|
||||||
|
case PartialImageResponse():
|
||||||
|
logger.info(
|
||||||
|
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||||
|
)
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
case ImageGenerationResponse():
|
||||||
|
logger.info("sending final ImageChunk")
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
# can we make this more explicit?
|
||||||
|
except Exception as e:
|
||||||
|
if _is_primary_output_node(self.shard_metadata):
|
||||||
|
self.event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=command_id,
|
||||||
|
chunk=ErrorChunk(
|
||||||
|
model=self.shard_metadata.model_card.model_id,
|
||||||
|
finish_reason="error",
|
||||||
|
error_message=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
_send_traces_if_enabled(
|
||||||
|
self.event_sender, task.task_id, self.device_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_status = RunnerReady()
|
||||||
|
logger.info("runner ready")
|
||||||
|
|
||||||
|
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||||
|
isinstance(self.current_status, RunnerReady)
|
||||||
|
):
|
||||||
|
assert self.image_model
|
||||||
|
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||||
|
logger.info("runner running")
|
||||||
|
self.update_status(RunnerRunning())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_index = 0
|
||||||
|
for response in generate_image(
|
||||||
|
model=self.image_model, task=task_params
|
||||||
|
):
|
||||||
|
if _is_primary_output_node(self.shard_metadata):
|
||||||
|
match response:
|
||||||
|
case PartialImageResponse():
|
||||||
|
logger.info(
|
||||||
|
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||||
|
)
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
case ImageGenerationResponse():
|
||||||
|
logger.info("sending final ImageChunk")
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
except Exception as e:
|
||||||
|
if _is_primary_output_node(self.shard_metadata):
|
||||||
|
self.event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=command_id,
|
||||||
|
chunk=ErrorChunk(
|
||||||
|
model=self.shard_metadata.model_card.model_id,
|
||||||
|
finish_reason="error",
|
||||||
|
error_message=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
_send_traces_if_enabled(
|
||||||
|
self.event_sender, task.task_id, self.device_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_status = RunnerReady()
|
||||||
|
logger.info("runner ready")
|
||||||
|
|
||||||
|
case Shutdown():
|
||||||
|
logger.info("runner shutting down")
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
del self.image_model, self.group
|
||||||
|
mx.clear_cache()
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
self.update_status(RunnerShuttingDown())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
self.current_status = RunnerShutdown()
|
||||||
|
case _:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received {task.__class__.__name__} outside of state machine in {self.current_status=}"
|
||||||
|
)
|
||||||
|
|||||||
346
src/exo/worker/runner/llm_inference/batch_generator.py
Normal file
346
src/exo/worker/runner/llm_inference/batch_generator.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import deque
|
||||||
|
from collections.abc import Generator, Iterable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
|
||||||
|
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
|
||||||
|
from exo.shared.types.common import ModelId
|
||||||
|
from exo.shared.types.events import ChunkGenerated, Event
|
||||||
|
from exo.shared.types.mlx import Model
|
||||||
|
from exo.shared.types.tasks import TaskId, TextGeneration
|
||||||
|
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||||
|
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||||
|
from exo.utils.channels import MpReceiver, MpSender
|
||||||
|
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||||
|
from exo.worker.engines.mlx.generator.generate import (
|
||||||
|
PrefillCancelled,
|
||||||
|
mlx_generate,
|
||||||
|
warmup_inference,
|
||||||
|
)
|
||||||
|
from exo.worker.engines.mlx.utils_mlx import (
|
||||||
|
apply_chat_template,
|
||||||
|
mx_any,
|
||||||
|
)
|
||||||
|
from exo.worker.runner.bootstrap import logger
|
||||||
|
|
||||||
|
from .model_output_parsers import apply_all_parsers
|
||||||
|
from .tool_parsers import ToolParser
|
||||||
|
|
||||||
|
|
||||||
|
class Cancelled:
    """Sentinel yielded by ``InferenceGenerator.step`` for a task that was cancelled."""

    pass
|
||||||
|
|
||||||
|
|
||||||
|
class Finished:
    """Sentinel yielded by ``InferenceGenerator.step`` when a task's generation completed."""

    pass
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorQueue[T]:
|
||||||
|
def __init__(self):
|
||||||
|
self._q = deque[T]()
|
||||||
|
|
||||||
|
def push(self, t: T):
|
||||||
|
self._q.append(t)
|
||||||
|
|
||||||
|
def gen(self) -> Generator[T | None]:
|
||||||
|
while True:
|
||||||
|
if len(self._q) == 0:
|
||||||
|
yield None
|
||||||
|
else:
|
||||||
|
yield self._q.popleft()
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceGenerator(ABC):
    """Abstract interface for driving text-generation inference.

    The runner loop calls ``warmup`` once at startup, ``submit`` to enqueue
    incoming tasks, ``step`` repeatedly to advance generation, and ``close``
    on shutdown to release model resources.
    """

    @abstractmethod
    def warmup(self) -> None: ...

    @abstractmethod
    def submit(
        self,
        task: TextGeneration,
    ) -> None: ...

    @abstractmethod
    def step(
        self,
    ) -> Iterable[
        # One (task_id, outcome) pair per event: a generated chunk, a tool
        # call, or a Cancelled/Finished sentinel marking the task's end.
        tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]
    ]: ...

    @abstractmethod
    def close(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||||
|
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||||
|
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
    """Check for debug prompt triggers in the input.

    If the first input message contains one of the EXO_RUNNER_MUST_* magic
    strings, artificially fail, OOM, or stall the runner. Used only by tests
    to exercise runner failure paths.

    Raises:
        Exception: when the prompt contains ``EXO_RUNNER_MUST_FAIL``.
    """
    # Local import to avoid pulling MLX into module import time.
    from exo.worker.engines.mlx.utils_mlx import mlx_force_oom

    if len(task_params.input) == 0:
        return
    # Only the first message is inspected for trigger strings.
    prompt = task_params.input[0].content
    if not prompt:
        return
    if EXO_RUNNER_MUST_FAIL in prompt:
        raise Exception("Artificial runner exception - for testing purposes only.")
    if EXO_RUNNER_MUST_OOM in prompt:
        mlx_force_oom()
    if EXO_RUNNER_MUST_TIMEOUT in prompt:
        # Deliberate stall so timeout handling can be exercised.
        time.sleep(100)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
class SequentialGenerator(InferenceGenerator):
    """Runs text-generation tasks one at a time (no batching).

    Tasks are first placed in ``_maybe_queue``; before starting work, all
    distributed ranks agree on a common task ordering via
    ``mx_all_gather_tasks`` and only agreed tasks move to ``_queue``.
    The active task is driven as a three-stage pipeline: the raw MLX
    generator pushes ``GenerationResponse`` items into a ``GeneratorQueue``,
    and a parser generator (``apply_all_parsers``) pulls from that queue to
    produce the final parsed outputs.
    """

    model: Model
    tokenizer: TokenizerWrapper
    group: mx.distributed.Group | None  # None when running single-device
    kv_prefix_cache: KVPrefixCache | None
    tool_parser: ToolParser | None
    model_id: ModelId
    device_rank: int  # rank 0 is the only rank that emits client-facing events
    cancel_receiver: MpReceiver[TaskId]
    event_sender: MpSender[Event]
    # How many generated tokens between cancellation checks; may be
    # overwritten during warmup (distributed path only — see NOTE in warmup).
    check_for_cancel_every: int = 50

    _cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)
    _maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)
    _queue: deque[TextGeneration] = field(default_factory=deque, init=False)
    _active: (
        tuple[
            TextGeneration,
            # mlx generator that does work
            Generator[GenerationResponse],
            # queue that the 1st generator should push to and 3rd generator should pull from
            GeneratorQueue[GenerationResponse],
            # generator to get parsed outputs
            Generator[GenerationResponse | ToolCallResponse | None],
        ]
        | None
    ) = field(default=None, init=False)

    def warmup(self):
        """Run a throwaway generation to warm the model and calibrate cancel-check cadence."""
        logger.info(f"warming up inference for instance: {self.model_id}")

        t = time.monotonic()
        toks = warmup_inference(
            model=self.model,
            tokenizer=self.tokenizer,
            group=self.group,
        )
        logger.info(f"warmed up by generating {toks} tokens")
        # NOTE(review): `min(elapsed, 0.001)` clamps elapsed DOWN to <=1ms,
        # which makes the ratio huge so the outer min() almost always picks
        # 100 — looks like `max(...)` (divide-by-zero guard) was intended;
        # confirm before relying on this cadence.
        check_for_cancel_every = min(
            math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
        )
        if self.group is not None:
            # All ranks must check for cancellation at the same token count,
            # so take the maximum across the group.
            self.check_for_cancel_every = int(
                mx.max(
                    mx.distributed.all_gather(
                        mx.array([check_for_cancel_every]),
                        group=self.group,
                    )
                ).item()
            )

        # NOTE(review): in the single-device path the computed local value is
        # never assigned to self.check_for_cancel_every (the default 50 is
        # kept), and this log prints the local value either way — verify
        # whether that is intentional.
        logger.info(
            f"runner checking for cancellation every {check_for_cancel_every} tokens"
        )

    def submit(
        self,
        task: TextGeneration,
    ) -> None:
        """Enqueue a task for rank agreement; clears any stale blanket-cancel marker."""
        self._cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
        self._maybe_queue.append(task)

    def agree_on_tasks(self) -> None:
        """Agree between all ranks about the task ordering (some may have received in different order or not at all)."""
        agreed, different = mx_all_gather_tasks(self._maybe_queue, self.group)
        # Agreed tasks become runnable; the rest stay pending for a later round.
        self._queue.extend(task for task in self._maybe_queue if task in agreed)
        self._maybe_queue = [task for task in self._maybe_queue if task in different]

    def step(
        self,
    ) -> Iterable[
        tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
    ]:
        """Advance the active task by one token and report outcomes.

        Returns an iterable of (task_id, outcome) pairs: at most one parsed
        response/Finished for the active task, plus a Cancelled marker for
        every task currently known to be cancelled.
        """
        if self._active is None:
            self.agree_on_tasks()

            if self._queue:
                self._start_next()
            else:
                # Nothing to run; still surface cancellations.
                return map(lambda task: (task, Cancelled()), self._cancelled_tasks)

        assert self._active is not None

        task, mlx_gen, queue, output_generator = self._active
        response = None
        try:
            # Pump one raw token into the queue, then pull one parsed output.
            queue.push(next(mlx_gen))
            response = next(output_generator)
        except (StopIteration, PrefillCancelled):
            # Generation finished (or was cancelled mid-prefill): retire the
            # pipeline and eagerly start the next queued task.
            response = Finished()
            self._active = None
            if self._queue:
                self._start_next()
        except Exception as e:
            # Report the failure to the client (rank 0 only) then propagate.
            self._send_error(task, e)
            self._active = None
            raise
        return itertools.chain(
            [] if response is None else [(task.task_id, response)],
            map(lambda task: (task, Cancelled()), self._cancelled_tasks),
        )

    def _start_next(self) -> None:
        """Pop the next agreed task and wire up its generation/parsing pipeline."""
        task = self._queue.popleft()
        try:
            mlx_gen = self._build_generator(task)
        except Exception as e:
            self._send_error(task, e)
            raise
        queue = GeneratorQueue[GenerationResponse]()
        output_generator = apply_all_parsers(
            queue.gen(),
            apply_chat_template(self.tokenizer, task.task_params),
            self.tool_parser,
            self.tokenizer,
            type(self.model),
            self.model_id,
            task.task_params.tools,
        )
        self._active = (task, mlx_gen, queue, output_generator)

    def _send_error(self, task: TextGeneration, e: Exception) -> None:
        """Emit an ErrorChunk for *task*; only rank 0 talks to the client."""
        if self.device_rank == 0:
            self.event_sender.send(
                ChunkGenerated(
                    command_id=task.command_id,
                    chunk=ErrorChunk(
                        model=self.model_id,
                        finish_reason="error",
                        error_message=str(e),
                    ),
                )
            )

    def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
        """Create the raw MLX token generator for *task* with progress/cancel hooks."""
        _check_for_debug_prompts(task.task_params)
        prompt = apply_chat_template(self.tokenizer, task.task_params)

        def on_prefill_progress(processed: int, total: int) -> None:
            # Stream prefill progress to the client from rank 0 only.
            if self.device_rank == 0:
                self.event_sender.send(
                    ChunkGenerated(
                        command_id=task.command_id,
                        chunk=PrefillProgressChunk(
                            model=self.model_id,
                            processed_tokens=processed,
                            total_tokens=total,
                        ),
                    )
                )

        def distributed_prompt_progress_callback() -> None:
            # Called during prefill: gather cancellations and abort if any
            # rank wants to cancel (mx_any keeps ranks in lockstep).
            self._cancelled_tasks.update(self.cancel_receiver.collect())
            want_to_cancel = (task.task_id in self._cancelled_tasks) or (
                TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
            )
            if mx_any(want_to_cancel, self.group):
                raise PrefillCancelled()

            self.agree_on_tasks()

        # Start at the threshold so the first generated token triggers a check.
        tokens_since_cancel_check = self.check_for_cancel_every

        def on_generation_token() -> None:
            # Called per generated token; check for cancellation only every
            # `check_for_cancel_every` tokens to limit sync overhead.
            nonlocal tokens_since_cancel_check
            tokens_since_cancel_check += 1
            if tokens_since_cancel_check >= self.check_for_cancel_every:
                tokens_since_cancel_check = 0
                self._cancelled_tasks.update(self.cancel_receiver.collect())
                want_to_cancel = (task.task_id in self._cancelled_tasks) or (
                    TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
                )
                # NOTE(review): PrefillCancelled is raised here even though
                # this fires during decode, not prefill — callers treat it as
                # a generic cancel signal; confirm the naming is intentional.
                if mx_any(want_to_cancel, self.group):
                    raise PrefillCancelled()

                self.agree_on_tasks()

        return mlx_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            task=task.task_params,
            prompt=prompt,
            kv_prefix_cache=self.kv_prefix_cache,
            on_prefill_progress=on_prefill_progress,
            distributed_prompt_progress_callback=distributed_prompt_progress_callback,
            on_generation_token=on_generation_token,
            group=self.group,
        )

    def close(self) -> None:
        """Drop references to the model, tokenizer, and group so memory can be reclaimed."""
        del self.model, self.tokenizer, self.group
|
||||||
|
|
||||||
|
|
||||||
|
def mx_all_gather_tasks(
    tasks: list[TextGeneration],
    group: mx.distributed.Group | None,
) -> tuple[list[TextGeneration], list[TextGeneration]]:
    """Gather task ids across all ranks and split local tasks into agreed/different.

    Each rank encodes its pending task ids as byte arrays, all-gathers them,
    and intersects the id sets. Returns ``(agreed, different)`` where
    ``agreed`` are local tasks every rank also has (sorted by id for a
    deterministic cross-rank order) and ``different`` are the remainder.
    """

    def encode_task_id(task_id: TaskId) -> list[int]:
        # One int per UTF-8 byte so ids can travel through an mx.array.
        utf8_task_id = task_id.encode()
        return [
            int.from_bytes(utf8_task_id[i : i + 1]) for i in range(len(utf8_task_id))
        ]

    def decode_task_id(encoded_task_id: list[int]) -> TaskId:
        return TaskId(
            bytes.decode(b"".join((x).to_bytes(length=1) for x in encoded_task_id))
        )

    # NOTE(review): padding rows are exactly 36 ints, so every real task id
    # must also encode to 36 bytes (a canonical UUID string) for mx.array to
    # accept the ragged-free matrix — confirm TaskId guarantees this.
    uuid_byte_length = 36

    n_tasks = len(tasks)
    # First collective: how many tasks does each rank hold?
    all_counts = cast(
        list[int],
        mx.distributed.all_gather(mx.array([n_tasks]), group=group).tolist(),
    )
    max_tasks = max(all_counts)
    world_size: int = 1 if group is None else group.size()

    if max_tasks == 0:
        return [], []

    # Pad every rank's id matrix to max_tasks rows so shapes match.
    padded = [encode_task_id(task.task_id) for task in tasks] + [
        [0] * uuid_byte_length
    ] * (max_tasks - n_tasks)
    # Second collective: gather all ranks' (padded) id matrices.
    gathered = cast(
        list[list[list[int]]],
        mx.distributed.all_gather(mx.array(padded), group=group)
        .reshape(world_size, max_tasks, -1)
        .tolist(),
    )
    # Drop padding rows using each rank's true count before decoding.
    all_task_ids: list[list[TaskId]] = [
        [decode_task_id(encoded_task_id) for encoded_task_id in rank_tasks[:count]]
        for rank_tasks, count in zip(gathered, all_counts, strict=True)
    ]

    # Intersect: only ids present on every rank are agreed.
    agreed_ids: set[TaskId] = set(all_task_ids[0])
    for rank_tasks in all_task_ids[1:]:
        agreed_ids &= set(rank_tasks)

    # agreed_ids is a subset of the local set (intersection includes this
    # rank), so every id resolves in local_tasks.
    local_tasks = {task.task_id: task for task in tasks}
    agreed = [local_tasks[tid] for tid in sorted(agreed_ids)]
    different = [task for task in tasks if task.task_id not in agreed_ids]
    return agreed, different
|
||||||
378
src/exo/worker/runner/llm_inference/model_output_parsers.py
Normal file
378
src/exo/worker/runner/llm_inference/model_output_parsers.py
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from functools import cache
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||||
|
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||||
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||||
|
HarmonyEncodingName,
|
||||||
|
HarmonyError, # pyright: ignore[reportUnknownVariableType]
|
||||||
|
Role,
|
||||||
|
StreamableParser,
|
||||||
|
load_harmony_encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from exo.shared.types.api import ToolCallItem
|
||||||
|
from exo.shared.types.common import ModelId
|
||||||
|
from exo.shared.types.mlx import Model
|
||||||
|
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||||
|
from exo.worker.engines.mlx.utils_mlx import (
|
||||||
|
detect_thinking_prompt_suffix,
|
||||||
|
)
|
||||||
|
from exo.worker.runner.bootstrap import logger
|
||||||
|
|
||||||
|
from .tool_parsers import ToolParser
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def get_gpt_oss_encoding():
    """Load the Harmony encoding for GPT-OSS once and memoize it."""
    return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all_parsers(
    receiver: Generator[GenerationResponse | None],
    prompt: str,
    tool_parser: ToolParser | None,
    tokenizer: TokenizerWrapper,
    model_type: type[Model],
    model_id: ModelId,
    tools: list[dict[str, Any]] | None,
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Chain the output parsers appropriate for this model onto *receiver*.

    Layers (in order): thinking-tag parsing when the tokenizer supports it,
    then exactly one of GPT-OSS Harmony parsing, DeepSeek V3.2 DSML parsing,
    or the generic tool-call parser. Returns the composed generator; items
    pass through untouched when no parser applies.
    """
    mlx_generator = receiver

    if tokenizer.has_thinking:
        mlx_generator = parse_thinking_models(
            mlx_generator,
            tokenizer,
            # Whether the prompt already ends inside a thinking block.
            starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),
        )

    if issubclass(model_type, GptOssModel):
        mlx_generator = parse_gpt_oss(mlx_generator)
    elif (
        # The model-id check guards against non-DeepSeek models that reuse
        # the DeepseekV32Model architecture class.
        issubclass(model_type, DeepseekV32Model)
        and "deepseek" in model_id.normalize().lower()
    ):
        mlx_generator = parse_deepseek_v32(mlx_generator)
    elif tool_parser:
        mlx_generator = parse_tool_calls(mlx_generator, tool_parser, tools)

    return mlx_generator
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gpt_oss(
    responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Parse a GPT-OSS Harmony token stream into text, thinking, and tool calls.

    Feeds each token through a ``StreamableParser``; content on the
    "analysis" channel is flagged as thinking, tokens addressed to a
    function recipient are buffered and emitted as a single
    ``ToolCallResponse`` when the recipient changes, and plain content
    deltas are re-emitted as copies of the incoming response. ``None``
    items (empty-queue markers) are passed through unchanged.
    """
    encoding = get_gpt_oss_encoding()
    stream = StreamableParser(encoding, role=Role.ASSISTANT)
    thinking = False
    current_tool_name: str | None = None
    tool_arg_parts: list[str] = []

    for response in responses:
        if response is None:
            yield None
            continue
        try:
            stream.process(response.token)
        except HarmonyError:
            # Parser state is unrecoverable; stop the stream early.
            logger.error("Encountered critical Harmony Error, returning early")
            return

        delta = stream.last_content_delta
        ch = stream.current_channel
        recipient = stream.current_recipient

        # Debug: log every token with state
        logger.debug(
            f"parse_gpt_oss token={response.token} text={response.text!r} "
            f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
            f"state={stream.state} current_tool={current_tool_name!r}"
        )

        # Recipient changed: flush the completed tool call (if any), then
        # start tracking the new recipient.
        if recipient != current_tool_name:
            if current_tool_name is not None:
                # Strip the Harmony "functions." namespace from the tool name.
                prefix = "functions."
                if current_tool_name.startswith(prefix):
                    current_tool_name = current_tool_name[len(prefix) :]
                logger.info(
                    f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
                )
                yield ToolCallResponse(
                    tool_calls=[
                        ToolCallItem(
                            name=current_tool_name,
                            arguments="".join(tool_arg_parts).strip(),
                        )
                    ],
                    usage=response.usage,
                )
                tool_arg_parts = []
            current_tool_name = recipient

        # If inside a tool call, accumulate arguments
        if current_tool_name is not None:
            if delta:
                tool_arg_parts.append(delta)
            continue

        # Track transitions in and out of the "analysis" (thinking) channel.
        if ch == "analysis" and not thinking:
            thinking = True

        if ch != "analysis" and thinking:
            thinking = False

        if delta:
            yield response.model_copy(update={"text": delta, "is_thinking": thinking})

        # Forward the terminal response so downstream sees the finish reason.
        if response.finish_reason is not None:
            yield response
|
||||||
|
|
||||||
|
|
||||||
|
def parse_deepseek_v32(
    responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Parse DeepSeek V3.2 DSML tool calls from the generation stream.

    Uses accumulated-text matching (not per-token marker checks) because
    DSML markers like <|DSML|function_calls> may span multiple tokens.
    Also handles <think>...</think> blocks for thinking mode.

    Yields:
        - ``None`` sentinels passed through unchanged,
        - ``GenerationResponse`` tokens (with ``is_thinking`` set while inside
          a think block),
        - one ``ToolCallResponse`` per successfully parsed DSML block.
    """
    # Imported lazily: marker strings plus the DSML block parser.
    from exo.worker.engines.mlx.dsml_encoding import (
        THINKING_END,
        THINKING_START,
        TOOL_CALLS_END,
        TOOL_CALLS_START,
        parse_dsml_output,
    )

    # Rolling text window used only to detect TOOL_CALLS_START.
    accumulated = ""
    in_tool_call = False
    thinking = False
    # Tokens buffered while we detect the start of a DSML block
    pending_buffer: list[GenerationResponse] = []
    # Text accumulated during a tool call block
    tool_call_text = ""

    for response in responses:
        if response is None:
            yield None
            continue

        # ── Handle thinking tags ──
        if not thinking and THINKING_START in response.text:
            thinking = True
            # Yield any text before the <think> tag
            before = response.text[: response.text.index(THINKING_START)]
            if before:
                yield response.model_copy(update={"text": before})
            continue

        if thinking and THINKING_END in response.text:
            thinking = False
            # Yield any text after the </think> tag
            after = response.text[
                response.text.index(THINKING_END) + len(THINKING_END) :
            ]
            if after:
                yield response.model_copy(update={"text": after, "is_thinking": False})
            continue

        if thinking:
            # Forward thinking-phase tokens with the flag set.
            yield response.model_copy(update={"is_thinking": True})
            continue

        # ── Handle tool call accumulation ──
        if in_tool_call:
            tool_call_text += response.text
            if TOOL_CALLS_END in tool_call_text:
                # Parse the accumulated DSML block
                parsed = parse_dsml_output(tool_call_text)
                if parsed is not None:
                    logger.info(f"parsed DSML tool calls: {parsed}")
                    yield ToolCallResponse(
                        tool_calls=parsed,
                        usage=response.usage,
                        stats=response.stats,
                    )
                else:
                    # Parsing failed: surface the raw block as plain text.
                    logger.warning(
                        f"DSML tool call parsing failed for: {tool_call_text}"
                    )
                    yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
                continue

            # EOS reached before end marker — yield buffered text as-is
            if response.finish_reason is not None:
                logger.info("DSML tool call parsing interrupted by EOS")
                yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
            # Still mid tool call — keep accumulating, never fall through to
            # start-marker detection below.
            continue

        # ── Detect start of tool call block ──
        accumulated += response.text

        if TOOL_CALLS_START in accumulated:
            # The start marker might be split across pending_buffer + current token
            start_idx = accumulated.index(TOOL_CALLS_START)
            # Yield any pending tokens that are purely before the marker
            pre_text = accumulated[:start_idx]
            if pre_text:
                # Flush pending buffer tokens that contributed text before the marker
                for buf_resp in pending_buffer:
                    if pre_text:
                        chunk = buf_resp.text
                        if len(chunk) <= len(pre_text):
                            # Whole buffered token precedes the marker.
                            yield buf_resp
                            pre_text = pre_text[len(chunk) :]
                        else:
                            # Token straddles the marker; yield only the prefix.
                            yield buf_resp.model_copy(update={"text": pre_text})
                            pre_text = ""
            pending_buffer = []
            tool_call_text = accumulated[start_idx:]
            accumulated = ""

            # Check if the end marker is already present (entire tool call in one token)
            if TOOL_CALLS_END in tool_call_text:
                parsed = parse_dsml_output(tool_call_text)
                if parsed is not None:
                    logger.info(f"parsed DSML tool calls: {parsed}")
                    yield ToolCallResponse(
                        tool_calls=parsed,
                        usage=response.usage,
                        stats=response.stats,
                    )
                else:
                    logger.warning(
                        f"DSML tool call parsing failed for: {tool_call_text}"
                    )
                    yield response.model_copy(update={"text": tool_call_text})
                tool_call_text = ""
            else:
                in_tool_call = True
            continue

        # Check if accumulated text might be the start of a DSML marker
        # Buffer tokens if we see a partial match at the end
        if _could_be_dsml_prefix(accumulated):
            pending_buffer.append(response)
            continue

        # No partial match — flush all pending tokens and the current one
        for buf_resp in pending_buffer:
            yield buf_resp
        pending_buffer = []
        accumulated = ""
        yield response

    # Flush any remaining pending buffer at generator end
    for buf_resp in pending_buffer:
        yield buf_resp
|
||||||
|
|
||||||
|
|
||||||
|
def _could_be_dsml_prefix(text: str) -> bool:
    """Return True when a suffix of *text* could begin a DSML function_calls
    marker.

    The caller buffers tokens while this holds, deferring the decision on
    whether a tool-call block is starting until the full marker can be
    confirmed or ruled out.
    """
    from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START

    # Only the trailing window of marker length can possibly overlap the marker.
    window = len(TOOL_CALLS_START)
    tail = text if len(text) <= window else text[-window:]

    # A partial match means some suffix of the tail is a prefix of the marker.
    return any(
        TOOL_CALLS_START.startswith(tail[offset:]) for offset in range(len(tail))
    )
|
||||||
|
|
||||||
|
|
||||||
|
def parse_thinking_models(
    responses: Generator[GenerationResponse | None],
    tokenizer: TokenizerWrapper,
    starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
    """Route thinking tokens via the is_thinking flag.

    Think-tag tokens themselves are swallowed (not forwarded), while every
    other token is forwarded with ``is_thinking`` reflecting the current
    state. A tag token carrying ``finish_reason`` is still forwarded (with
    empty text) so the downstream chunk stream can terminate.
    """
    currently_thinking = starts_in_thinking

    for item in responses:
        if item is None:
            yield None
            continue

        start_tag = tokenizer.think_start
        end_tag = tokenizer.think_end
        matches_tag = (end_tag is not None and item.text == end_tag) or (
            start_tag is not None and item.text == start_tag
        )

        if not matches_tag:
            yield item.model_copy(update={"is_thinking": currently_thinking})
            continue

        # Entering thinking on a start tag, leaving it on an end tag.
        currently_thinking = item.text != end_tag
        # Never swallow finish_reason — the chunk stream needs it to terminate.
        if item.finish_reason is not None:
            yield item.model_copy(update={"text": "", "is_thinking": False})
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_calls(
    responses: Generator[GenerationResponse | None],
    tool_parser: ToolParser,
    tools: list[dict[str, Any]] | None,
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Collect and parse tool-call blocks delimited by ``tool_parser``'s
    start/end markers.

    Tokens from a token that starts with ``start_parsing`` through a token
    that ends with ``end_parsing`` are buffered rather than forwarded; on the
    closing token the buffered text is parsed into a ``ToolCallResponse``
    (falling back to yielding the raw text when parsing fails). ``None``
    sentinels and all other tokens pass through unchanged.

    NOTE(review): marker detection is per-token (``startswith``/``endswith``),
    so this assumes the model emits each marker within a single token —
    confirm for models routed here (multi-token markers are handled by
    parse_deepseek_v32 instead).
    """
    in_tool_call = False
    tool_call_text_parts: list[str] = []
    for response in responses:
        if response is None:
            yield None
            continue
        if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
            in_tool_call = True

        if in_tool_call:
            tool_call_text_parts.append(response.text)
            if response.text.endswith(tool_parser.end_parsing):
                # parse the actual tool calls from the tool call text
                parsed = tool_parser.parse("".join(tool_call_text_parts).strip(), tools)
                logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
                if parsed is not None:
                    yield ToolCallResponse(
                        tool_calls=parsed, usage=response.usage, stats=response.stats
                    )
                else:
                    logger.warning(
                        f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
                    )
                    # Parsing failed: surface the raw block as ordinary text.
                    response.text = "".join(tool_call_text_parts)
                    yield response

                in_tool_call = False
                tool_call_text_parts = []
                continue

            # EOS hit mid tool call: emit the partial block as text so the
            # stream still terminates cleanly.
            if response.finish_reason is not None:
                logger.info(
                    "tool call parsing interrupted, yield partial tool call as text"
                )
                response = response.model_copy(
                    update={
                        "text": "".join(tool_call_text_parts),
                        "token": 0,
                    }
                )
                yield response

        else:
            # fallthrough
            yield response
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
@@ -9,7 +10,177 @@ from exo.shared.types.api import ToolCallItem
|
|||||||
class ToolParser:
|
class ToolParser:
|
||||||
start_parsing: str
|
start_parsing: str
|
||||||
end_parsing: str
|
end_parsing: str
|
||||||
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
|
_inner_parser: Callable[[str], list[ToolCallItem] | None]
|
||||||
|
|
||||||
|
def parse(
|
||||||
|
self, text: str, tools: list[dict[str, Any]] | None
|
||||||
|
) -> list[ToolCallItem] | None:
|
||||||
|
parsed = self._inner_parser(text)
|
||||||
|
if parsed is None:
|
||||||
|
return None
|
||||||
|
if tools is not None:
|
||||||
|
parsed = _coerce_tool_calls_to_schema(parsed, tools)
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _json_type_matches(value: Any, expected_type: str) -> bool: # pyright: ignore[reportAny]
|
||||||
|
if expected_type == "object":
|
||||||
|
return isinstance(value, dict)
|
||||||
|
if expected_type == "array":
|
||||||
|
return isinstance(value, list)
|
||||||
|
if expected_type == "string":
|
||||||
|
return isinstance(value, str)
|
||||||
|
if expected_type == "integer":
|
||||||
|
return isinstance(value, int) and not isinstance(value, bool)
|
||||||
|
if expected_type == "number":
|
||||||
|
return (isinstance(value, int) and not isinstance(value, bool)) or isinstance(
|
||||||
|
value, float
|
||||||
|
)
|
||||||
|
if expected_type == "boolean":
|
||||||
|
return isinstance(value, bool)
|
||||||
|
if expected_type == "null":
|
||||||
|
return value is None
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_tool_arg_with_schema(value: Any, schema: dict[str, Any]) -> Any:  # pyright: ignore[reportAny]
    """Best-effort coercion of *value* toward the JSON-schema *schema*.

    Recursively coerces objects/arrays (including JSON-encoded strings of
    either), parses numeric/boolean strings, and resolves union type lists
    by trying each candidate in order. On any mismatch or parse failure the
    original value is returned unchanged — this never raises.
    """
    schema_type = schema.get("type")

    # Union types: try each candidate type in declaration order and keep the
    # first coercion whose result actually matches that candidate.
    if isinstance(schema_type, list):
        for candidate in schema_type:  # pyright: ignore[reportUnknownVariableType]
            if not isinstance(candidate, str):
                continue
            if candidate == "null" and value is None:
                return None
            candidate_schema = {**schema, "type": candidate}
            coerced = _coerce_tool_arg_with_schema(value, candidate_schema)  # pyright: ignore[reportAny]
            if _json_type_matches(coerced, candidate):
                return coerced  # pyright: ignore[reportAny]
        return value  # pyright: ignore[reportAny]

    # No usable "type" entry — nothing to coerce against.
    if not isinstance(schema_type, str):
        return value  # pyright: ignore[reportAny]

    if schema_type == "object":
        parsed = value  # pyright: ignore[reportAny]
        # A string value may be a JSON-encoded object; decode before recursing.
        if isinstance(parsed, str):
            try:
                parsed = json.loads(parsed)  # pyright: ignore[reportAny]
            except Exception:
                return value  # pyright: ignore[reportAny]
        if not isinstance(parsed, dict):
            return value  # pyright: ignore[reportAny]
        properties = schema.get("properties")
        if not isinstance(properties, dict):
            return parsed  # pyright: ignore[reportUnknownVariableType]
        # Recurse per property; the single-element inner loop binds
        # prop_schema once per key (a comprehension-scoped "let").
        return {
            key: (
                _coerce_tool_arg_with_schema(prop_value, prop_schema)  # pyright: ignore[reportUnknownArgumentType]
                if isinstance(prop_schema, dict)
                else prop_value
            )
            for key, prop_value in parsed.items()  # pyright: ignore[reportUnknownVariableType]
            for prop_schema in [properties.get(key)]  # type: ignore
        }

    if schema_type == "array":
        parsed = value  # pyright: ignore[reportAny]
        # A string value may be a JSON-encoded array; decode before recursing.
        if isinstance(parsed, str):
            try:
                parsed = json.loads(parsed)  # pyright: ignore[reportAny]
            except Exception:
                return value  # pyright: ignore[reportAny]
        if not isinstance(parsed, list):
            return value  # pyright: ignore[reportAny]
        item_schema = schema.get("items")
        if not isinstance(item_schema, dict):
            return parsed  # pyright: ignore[reportUnknownVariableType]
        return [_coerce_tool_arg_with_schema(item, item_schema) for item in parsed]  # type: ignore

    if schema_type == "integer":
        # Bools are ints in Python; deliberately left unconverted.
        if isinstance(value, bool):
            return value
        if isinstance(value, int):
            return value
        # Accept whole-valued floats (e.g. 3.0 -> 3).
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return value
        return value

    if schema_type == "number":
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, str):
            try:
                num = float(value.strip())
                # Reject "inf"/"nan": not representable in JSON numbers.
                if math.isfinite(num):
                    return num
            except ValueError:
                return value
        return value

    if schema_type == "boolean":
        if isinstance(value, bool):
            return value
        # Only the exact (case-insensitive) words "true"/"false" convert.
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered == "true":
                return True
            if lowered == "false":
                return False
        return value

    # "string", "null", or any unrecognized type name: pass through as-is.
    return value  # pyright: ignore[reportAny]
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_tool_calls_to_schema(
    tool_calls: list[ToolCallItem], tools: list[dict[str, Any]]
) -> list[ToolCallItem]:
    """Coerce each tool call's JSON arguments toward its tool's declared
    parameter schema.

    Calls whose tool has no usable schema, or whose arguments cannot be
    decoded or coerced to a dict, pass through unchanged.
    """
    # Index the declared parameter schemas by function name.
    schemas: dict[str, dict[str, Any]] = {}
    for entry in tools:
        fn = entry.get("function")
        if not isinstance(fn, dict):
            continue
        fn_name = fn.get("name")  # type: ignore
        fn_params = fn.get("parameters")  # type: ignore
        if isinstance(fn_name, str) and isinstance(fn_params, dict):
            schemas[fn_name] = fn_params

    if not schemas:
        return tool_calls

    def _coerce_one(call: ToolCallItem) -> ToolCallItem:
        # Any lookup/parse/coercion miss falls back to the untouched call.
        schema = schemas.get(call.name)
        if schema is None:
            return call
        try:
            raw_args = json.loads(call.arguments)  # pyright: ignore[reportAny]
        except Exception:
            return call
        if not isinstance(raw_args, dict):
            return call
        fixed = _coerce_tool_arg_with_schema(raw_args, schema)  # pyright: ignore[reportAny]
        if not isinstance(fixed, dict):
            return call
        return call.model_copy(update={"arguments": json.dumps(fixed)})

    return [_coerce_one(call) for call in tool_calls]
|
||||||
|
|
||||||
|
|
||||||
def make_mlx_parser(
|
def make_mlx_parser(
|
||||||
@@ -33,7 +204,7 @@ def make_mlx_parser(
|
|||||||
return ToolParser(
|
return ToolParser(
|
||||||
start_parsing=tool_call_start,
|
start_parsing=tool_call_start,
|
||||||
end_parsing=tool_call_end,
|
end_parsing=tool_call_end,
|
||||||
parse_tool_calls=parse_tool_calls,
|
_inner_parser=parse_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +233,7 @@ def make_json_parser() -> ToolParser:
|
|||||||
return ToolParser(
|
return ToolParser(
|
||||||
start_parsing="<tool_call>",
|
start_parsing="<tool_call>",
|
||||||
end_parsing="</tool_call>",
|
end_parsing="</tool_call>",
|
||||||
parse_tool_calls=_parse_json_calls,
|
_inner_parser=_parse_json_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,13 +12,22 @@ from anyio import (
|
|||||||
)
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from exo.shared.types.chunks import ErrorChunk
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
|
ChunkGenerated,
|
||||||
Event,
|
Event,
|
||||||
RunnerStatusUpdated,
|
RunnerStatusUpdated,
|
||||||
TaskAcknowledged,
|
TaskAcknowledged,
|
||||||
TaskStatusUpdated,
|
TaskStatusUpdated,
|
||||||
)
|
)
|
||||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
from exo.shared.types.tasks import (
|
||||||
|
ImageEdits,
|
||||||
|
ImageGeneration,
|
||||||
|
Task,
|
||||||
|
TaskId,
|
||||||
|
TaskStatus,
|
||||||
|
TextGeneration,
|
||||||
|
)
|
||||||
from exo.shared.types.worker.instances import BoundInstance
|
from exo.shared.types.worker.instances import BoundInstance
|
||||||
from exo.shared.types.worker.runners import (
|
from exo.shared.types.worker.runners import (
|
||||||
RunnerConnecting,
|
RunnerConnecting,
|
||||||
@@ -52,6 +61,7 @@ class RunnerSupervisor:
|
|||||||
_tg: TaskGroup = field(default_factory=TaskGroup, init=False)
|
_tg: TaskGroup = field(default_factory=TaskGroup, init=False)
|
||||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||||
|
in_progress: dict[TaskId, Task] = field(default_factory=dict, init=False)
|
||||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||||
_cancel_watch_runner: anyio.CancelScope = field(
|
_cancel_watch_runner: anyio.CancelScope = field(
|
||||||
@@ -147,9 +157,11 @@ class RunnerSupervisor:
|
|||||||
logger.info(f"Starting task {task}")
|
logger.info(f"Starting task {task}")
|
||||||
event = anyio.Event()
|
event = anyio.Event()
|
||||||
self.pending[task.task_id] = event
|
self.pending[task.task_id] = event
|
||||||
|
self.in_progress[task.task_id] = task
|
||||||
try:
|
try:
|
||||||
await self._task_sender.send_async(task)
|
await self._task_sender.send_async(task)
|
||||||
except ClosedResourceError:
|
except ClosedResourceError:
|
||||||
|
self.in_progress.pop(task.task_id, None)
|
||||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||||
return
|
return
|
||||||
await event.wait()
|
await event.wait()
|
||||||
@@ -157,10 +169,17 @@ class RunnerSupervisor:
|
|||||||
async def cancel_task(self, task_id: TaskId):
|
async def cancel_task(self, task_id: TaskId):
|
||||||
if task_id in self.completed:
|
if task_id in self.completed:
|
||||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||||
|
self.cancelled.add(task_id)
|
||||||
return
|
return
|
||||||
self.cancelled.add(task_id)
|
self.cancelled.add(task_id)
|
||||||
with anyio.move_on_after(0.5) as scope:
|
with anyio.move_on_after(0.5) as scope:
|
||||||
await self._cancel_sender.send_async(task_id)
|
try:
|
||||||
|
await self._cancel_sender.send_async(task_id)
|
||||||
|
except ClosedResourceError:
|
||||||
|
# typically occurs when trying to shut down a failed instance
|
||||||
|
logger.warning(
|
||||||
|
f"Cancelling task {task_id} failed, runner closed communication"
|
||||||
|
)
|
||||||
if scope.cancel_called:
|
if scope.cancel_called:
|
||||||
logger.error("RunnerSupervisor cancel pipe blocked")
|
logger.error("RunnerSupervisor cancel pipe blocked")
|
||||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||||
@@ -189,6 +208,7 @@ class RunnerSupervisor:
|
|||||||
RunnerShuttingDown,
|
RunnerShuttingDown,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self.in_progress.pop(event.task_id, None)
|
||||||
self.completed.add(event.task_id)
|
self.completed.add(event.task_id)
|
||||||
await self._event_sender.send(event)
|
await self._event_sender.send(event)
|
||||||
except (ClosedResourceError, BrokenResourceError) as e:
|
except (ClosedResourceError, BrokenResourceError) as e:
|
||||||
@@ -233,6 +253,22 @@ class RunnerSupervisor:
|
|||||||
|
|
||||||
logger.opt(exception=e).error(f"Runner terminated with {cause}")
|
logger.opt(exception=e).error(f"Runner terminated with {cause}")
|
||||||
|
|
||||||
|
for task in self.in_progress.values():
|
||||||
|
if isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)):
|
||||||
|
with anyio.CancelScope(shield=True):
|
||||||
|
await self._event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=task.command_id,
|
||||||
|
chunk=ErrorChunk(
|
||||||
|
model=self.shard_metadata.model_card.model_id,
|
||||||
|
error_message=(
|
||||||
|
"Runner shutdown before completing command "
|
||||||
|
f"({cause})"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.status = RunnerFailed(error_message=f"Terminated ({cause})")
|
self.status = RunnerFailed(error_message=f"Terminated ({cause})")
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class FakeRunnerSupervisor:
|
|||||||
bound_instance: BoundInstance
|
bound_instance: BoundInstance
|
||||||
status: RunnerStatus
|
status: RunnerStatus
|
||||||
completed: set[TaskId] = field(default_factory=set)
|
completed: set[TaskId] = field(default_factory=set)
|
||||||
|
in_progress: set[TaskId] = field(default_factory=set)
|
||||||
|
pending: dict[TaskId, object] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class OtherTask(BaseTask):
|
class OtherTask(BaseTask):
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
|
|||||||
encode_messages,
|
encode_messages,
|
||||||
parse_dsml_output,
|
parse_dsml_output,
|
||||||
)
|
)
|
||||||
from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
|
from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
|
||||||
|
|
||||||
# ── Shared fixtures ──────────────────────────────────────────────
|
# ── Shared fixtures ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import Callable
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
|
||||||
|
import exo.worker.runner.llm_inference.model_output_parsers as mlx_model_output_parsers
|
||||||
import exo.worker.runner.llm_inference.runner as mlx_runner
|
import exo.worker.runner.llm_inference.runner as mlx_runner
|
||||||
from exo.shared.types.chunks import TokenChunk
|
from exo.shared.types.chunks import TokenChunk
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
@@ -114,27 +116,41 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
|||||||
# initialize_mlx returns a mock group
|
# initialize_mlx returns a mock group
|
||||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1))
|
||||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
|
||||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
|
||||||
|
|
||||||
|
def fake_all_gather(
|
||||||
|
tasks: list[TextGeneration], group: object
|
||||||
|
) -> tuple[list[TextGeneration], list[TextGeneration]]:
|
||||||
|
return (tasks, [])
|
||||||
|
|
||||||
|
monkeypatch.setattr(mlx_batch_generator, "mx_all_gather_tasks", fake_all_gather)
|
||||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mlx_model_output_parsers, "detect_thinking_prompt_suffix", make_nothin(False)
|
||||||
|
)
|
||||||
|
|
||||||
def fake_generate(*_1: object, **_2: object):
|
def fake_generate(*_1: object, **_2: object):
|
||||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
||||||
|
|
||||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)
|
||||||
|
|
||||||
|
|
||||||
# Use a fake event_sender to remove test flakiness.
|
# Use a fake event_sender to remove test flakiness.
|
||||||
class EventCollector:
|
class EventCollector:
|
||||||
def __init__(self) -> None:
|
def __init__(self, on_event: Callable[[Event], None] | None = None) -> None:
|
||||||
self.events: list[Event] = []
|
self.events: list[Event] = []
|
||||||
|
self._on_event = on_event
|
||||||
|
|
||||||
def send(self, event: Event) -> None:
|
def send(self, event: Event) -> None:
|
||||||
self.events.append(event)
|
self.events.append(event)
|
||||||
|
if self._on_event:
|
||||||
|
self._on_event(event)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
pass
|
pass
|
||||||
@@ -159,7 +175,7 @@ class MockGroup:
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def _run(tasks: Iterable[Task]):
|
def _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):
|
||||||
bound_instance = get_bound_mlx_ring_instance(
|
bound_instance = get_bound_mlx_ring_instance(
|
||||||
instance_id=INSTANCE_1_ID,
|
instance_id=INSTANCE_1_ID,
|
||||||
model_id=MODEL_A_ID,
|
model_id=MODEL_A_ID,
|
||||||
@@ -169,7 +185,23 @@ def _run(tasks: Iterable[Task]):
|
|||||||
|
|
||||||
task_sender, task_receiver = mp_channel[Task]()
|
task_sender, task_receiver = mp_channel[Task]()
|
||||||
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
|
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
|
||||||
event_sender = EventCollector()
|
|
||||||
|
on_event: Callable[[Event], None] | None = None
|
||||||
|
if send_after_ready:
|
||||||
|
_saw_running = False
|
||||||
|
|
||||||
|
def _on_event(event: Event) -> None:
|
||||||
|
nonlocal _saw_running
|
||||||
|
if isinstance(event, RunnerStatusUpdated):
|
||||||
|
if isinstance(event.runner_status, RunnerRunning):
|
||||||
|
_saw_running = True
|
||||||
|
elif _saw_running and isinstance(event.runner_status, RunnerReady):
|
||||||
|
for t in send_after_ready:
|
||||||
|
task_sender.send(t)
|
||||||
|
|
||||||
|
on_event = _on_event
|
||||||
|
|
||||||
|
event_sender = EventCollector(on_event=on_event)
|
||||||
|
|
||||||
with task_sender:
|
with task_sender:
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
@@ -183,18 +215,22 @@ def _run(tasks: Iterable[Task]):
|
|||||||
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
|
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
|
||||||
make_nothin(mx.array([1])),
|
make_nothin(mx.array([1])),
|
||||||
):
|
):
|
||||||
mlx_runner.main(
|
runner = mlx_runner.Runner(
|
||||||
bound_instance,
|
bound_instance,
|
||||||
event_sender, # pyright: ignore[reportArgumentType]
|
event_sender, # pyright: ignore[reportArgumentType]
|
||||||
task_receiver,
|
task_receiver,
|
||||||
cancel_receiver,
|
cancel_receiver,
|
||||||
)
|
)
|
||||||
|
runner.main()
|
||||||
|
|
||||||
return event_sender.events
|
return event_sender.events
|
||||||
|
|
||||||
|
|
||||||
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||||
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
|
events = _run(
|
||||||
|
[INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK],
|
||||||
|
send_after_ready=[SHUTDOWN_TASK],
|
||||||
|
)
|
||||||
|
|
||||||
expected_chunk = ChunkGenerated(
|
expected_chunk = ChunkGenerated(
|
||||||
command_id=COMMAND_1_ID,
|
command_id=COMMAND_1_ID,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
|
|||||||
GenerationResponse,
|
GenerationResponse,
|
||||||
ToolCallResponse,
|
ToolCallResponse,
|
||||||
)
|
)
|
||||||
from exo.worker.runner.llm_inference.runner import parse_gpt_oss
|
from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss
|
||||||
|
|
||||||
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
|
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
|
||||||
# These are stable since they come from the model's vocabulary.
|
# These are stable since they come from the model's vocabulary.
|
||||||
@@ -107,7 +107,7 @@ def _collect(
|
|||||||
def _gen() -> Generator[GenerationResponse, None, None]:
|
def _gen() -> Generator[GenerationResponse, None, None]:
|
||||||
yield from _make_gen_responses(tokens)
|
yield from _make_gen_responses(tokens)
|
||||||
|
|
||||||
return list(parse_gpt_oss(_gen()))
|
return list(x for x in parse_gpt_oss(_gen()) if x is not None)
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_call(
|
def _get_tool_call(
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""Tests for parse_tool_calls generator, especially unclosed tool call handling."""
|
"""Tests for parse_tool_calls generator, especially unclosed tool call handling."""
|
||||||
|
|
||||||
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||||
from exo.worker.runner.llm_inference.runner import parse_tool_calls
|
from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
|
||||||
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
|
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
|
||||||
|
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ class TestParseToolCalls:
|
|||||||
parse_tool_calls(
|
parse_tool_calls(
|
||||||
_make_responses(texts, finish_on_last=False),
|
_make_responses(texts, finish_on_last=False),
|
||||||
_dummy_parser,
|
_dummy_parser,
|
||||||
|
tools=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,6 +55,7 @@ class TestParseToolCalls:
|
|||||||
parse_tool_calls(
|
parse_tool_calls(
|
||||||
_make_responses(texts),
|
_make_responses(texts),
|
||||||
_dummy_parser,
|
_dummy_parser,
|
||||||
|
tools=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -77,9 +80,101 @@ class TestParseToolCalls:
|
|||||||
parse_tool_calls(
|
parse_tool_calls(
|
||||||
_make_responses(texts, finish_on_last=False),
|
_make_responses(texts, finish_on_last=False),
|
||||||
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
|
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
|
||||||
|
tools=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert isinstance(results[0], GenerationResponse)
|
assert isinstance(results[0], GenerationResponse)
|
||||||
assert results[0].text == "<tool_call>bad content</tool_call>"
|
assert results[0].text == "<tool_call>bad content</tool_call>"
|
||||||
|
|
||||||
|
def test_tool_schema_coerces_string_arguments_to_expected_types(self):
|
||||||
|
"""Tool argument values should be coerced using provided JSON schema."""
|
||||||
|
|
||||||
|
def _parser_with_string_args(_text: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": "process",
|
||||||
|
"arguments": {
|
||||||
|
"action": "output",
|
||||||
|
"id": "0",
|
||||||
|
"verbose": "true",
|
||||||
|
"temperature": "0.75",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "process",
|
||||||
|
"description": "Manage background processes",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {"type": "string"},
|
||||||
|
"id": {"type": "integer"},
|
||||||
|
"verbose": {"type": "boolean"},
|
||||||
|
"temperature": {"type": "number"},
|
||||||
|
},
|
||||||
|
"required": ["action"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
results = list(
|
||||||
|
parse_tool_calls(
|
||||||
|
_make_responses(["<tool_call>", "process", "</tool_call>"]),
|
||||||
|
make_mlx_parser(
|
||||||
|
"<tool_call>", "</tool_call>", _parser_with_string_args
|
||||||
|
),
|
||||||
|
tools,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], ToolCallResponse)
|
||||||
|
|
||||||
|
args = json.loads(results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||||
|
assert args == {
|
||||||
|
"action": "output",
|
||||||
|
"id": 0,
|
||||||
|
"verbose": True,
|
||||||
|
"temperature": 0.75,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_schema_coercion_skips_unknown_tools(self):
|
||||||
|
"""If no matching tool schema exists, arguments should remain unchanged."""
|
||||||
|
|
||||||
|
def _parser_with_string_id(_text: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": "process",
|
||||||
|
"arguments": {"action": "output", "id": "0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "different_tool",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"id": {"type": "integer"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
results = list(
|
||||||
|
parse_tool_calls(
|
||||||
|
_make_responses(["<tool_call>", "process", "</tool_call>"]),
|
||||||
|
make_mlx_parser("<tool_call>", "</tool_call>", _parser_with_string_id),
|
||||||
|
tools,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], ToolCallResponse)
|
||||||
|
|
||||||
|
args = json.loads(results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||||
|
assert args == {"action": "output", "id": "0"}
|
||||||
|
|||||||
@@ -1 +1,93 @@
|
|||||||
# TODO:
|
import multiprocessing as mp
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from exo.shared.models.model_cards import ModelId
|
||||||
|
from exo.shared.types.chunks import ErrorChunk
|
||||||
|
from exo.shared.types.common import CommandId, NodeId
|
||||||
|
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
|
||||||
|
from exo.shared.types.tasks import Task, TaskId, TextGeneration
|
||||||
|
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||||
|
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||||
|
from exo.shared.types.worker.runners import RunnerFailed, RunnerId
|
||||||
|
from exo.utils.channels import channel, mp_channel
|
||||||
|
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||||
|
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||||
|
|
||||||
|
|
||||||
|
class _DeadProcess:
|
||||||
|
exitcode = -6
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_alive(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def join(self, _timeout: float | None = None) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def terminate(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def kill(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> None:
|
||||||
|
event_sender, event_receiver = channel[Event]()
|
||||||
|
task_sender, _ = mp_channel[Task]()
|
||||||
|
cancel_sender, _ = mp_channel[TaskId]()
|
||||||
|
_, ev_recv = mp_channel[Event]()
|
||||||
|
|
||||||
|
bound_instance: BoundInstance = get_bound_mlx_ring_instance(
|
||||||
|
instance_id=InstanceId("instance-a"),
|
||||||
|
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||||
|
runner_id=RunnerId("runner-a"),
|
||||||
|
node_id=NodeId("node-a"),
|
||||||
|
)
|
||||||
|
|
||||||
|
supervisor = RunnerSupervisor(
|
||||||
|
shard_metadata=bound_instance.bound_shard,
|
||||||
|
bound_instance=bound_instance,
|
||||||
|
runner_process=cast("mp.Process", cast(object, _DeadProcess())),
|
||||||
|
initialize_timeout=400,
|
||||||
|
_ev_recv=ev_recv,
|
||||||
|
_task_sender=task_sender,
|
||||||
|
_event_sender=event_sender,
|
||||||
|
_cancel_sender=cancel_sender,
|
||||||
|
)
|
||||||
|
|
||||||
|
command_id = CommandId("cmd-a")
|
||||||
|
task = TextGeneration(
|
||||||
|
task_id=TaskId("task-a"),
|
||||||
|
instance_id=bound_instance.instance.instance_id,
|
||||||
|
command_id=command_id,
|
||||||
|
task_params=TextGenerationTaskParams(
|
||||||
|
model=bound_instance.bound_shard.model_card.model_id,
|
||||||
|
input=[InputMessage(role="user", content="hi")],
|
||||||
|
stream=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
supervisor.in_progress[task.task_id] = task
|
||||||
|
supervisor.shutdown = lambda: None
|
||||||
|
|
||||||
|
await supervisor._check_runner(RuntimeError("boom")) # pyright: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
|
got_chunk = await event_receiver.receive()
|
||||||
|
got_status = await event_receiver.receive()
|
||||||
|
|
||||||
|
assert isinstance(got_chunk, ChunkGenerated)
|
||||||
|
assert got_chunk.command_id == command_id
|
||||||
|
assert isinstance(got_chunk.chunk, ErrorChunk)
|
||||||
|
assert "Runner shutdown before completing command" in got_chunk.chunk.error_message
|
||||||
|
|
||||||
|
assert isinstance(got_status, RunnerStatusUpdated)
|
||||||
|
assert isinstance(got_status.runner_status, RunnerFailed)
|
||||||
|
|
||||||
|
event_sender.close()
|
||||||
|
with anyio.move_on_after(0.1):
|
||||||
|
await event_receiver.aclose()
|
||||||
|
|||||||
Reference in New Issue
Block a user