Compare commits

...

7 Commits

Author SHA1 Message Date
Alex Cheema
97e8da5efd feat: show download availability in model picker
Add download status indicators to the model picker modal:
- Each model group shows a green checkmark if it's downloaded on nodes
  with enough total RAM to run it
- The info (i) modal now includes a "Downloaded on:" section listing
  the friendly names of nodes that have the model
- Availability is computed by checking which nodes have DownloadCompleted
  entries and summing their total RAM against the model's storage size

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 06:11:22 -08:00
Alex Cheema
c381ae64ad feat: show download status indicators on topology nodes
Add per-node download status indicators to the topology view for the
currently selected model. Nodes with completed downloads show a green
checkmark; nodes with pending/ongoing downloads show a clickable
download button.

Also extracts shared download parsing utilities into
dashboard/src/lib/utils/downloads.ts and refactors the downloads page
to use them.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 06:07:48 -08:00
Alex Cheema
41ed7afb3b feat: add model picker modal with grouped models and HF Hub search (#1369)
## Motivation

Reimplements the model picker modal from #1191 on top of the custom
model support branch. Replaces the inline model dropdown with a
full-featured modal that groups models by base model, supports
filtering, favorites, and HuggingFace Hub search.

## Changes

**Backend:**
- Add `family`, `quantization`, `base_model`, `capabilities` metadata
fields to `ModelCard` and all 40 TOML model cards
- Pass new fields through `ModelListModel` and `get_models()` API
response
- Add `GET /models/search` endpoint using
`huggingface_hub.list_models()`

**Dashboard (7 new files):**
- `ModelPickerModal.svelte` — Main modal with search, family filtering,
HuggingFace Hub tab
- `ModelPickerGroup.svelte` — Expandable model group row with
quantization variants
- `FamilySidebar.svelte` — Vertical sidebar with family icons (All,
Favorites, Hub, model families)
- `FamilyLogos.svelte` — SVG icons for each model family
- `ModelFilterPopover.svelte` — Capability and size range filters
- `HuggingFaceResultItem.svelte` — HF search result item with
download/like counts
- `favorites.svelte.ts` — localStorage-backed favorites store

**Integration:**
- Replace inline dropdown in `+page.svelte` with button that opens
`ModelPickerModal`
- Custom models shown in Hub tab with delete support

**Polish:**
- Real brand logos (Meta, Qwen, DeepSeek, OpenAI, GLM, MiniMax, Kimi,
HuggingFace) from Simple Icons / LobeHub
- Clean SVG stroke icons for capabilities (thinking, code, vision, image
gen)
- Consistent `border-exo-yellow/10` borders, descriptive tooltips
throughout
- Cluster memory (used/total) shown in modal header
- Selected model highlight with checkmark for both single and
multi-variant groups
- Cursor pointer on all interactive elements, fix filter popover
click-outside bug
- Custom models now appear in All tab alongside built-in models

## Bug Fix: Gemma 3 EOS tokens

Also included in this branch: fix for Gemma 3 models generating infinite
`<end_of_turn>` tokens. The tokenizer's `eos_token_ids` was missing
token ID 106 (`<end_of_turn>`), so generation never stopped. The fix
appends this token to the EOS list after loading the tokenizer. Also
handles `eos_token_ids` being a `set` (not just a `list`).

## Why It Works

Model metadata (family, capabilities, etc.) is stored directly in TOML
cards rather than derived from heuristics, ensuring accuracy. The modal
groups models by `base_model` field so quantization variants appear
together. Custom models are separated into the Hub tab since they lack
grouping metadata.

## Test Plan

### Manual Testing
- Open dashboard, click model selector to open modal
- Browse models by family sidebar, search, and filters
- Expand model groups to see quantization variants
- Star favorites and verify persistence across page reloads
- Navigate to Hub tab, search and add models
- Verify error messages shown for invalid model IDs
- Run a Gemma 3 model and verify generation stops at `<end_of_turn>`

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — clean
- `uv run pytest src/` — 173 passed
- `cd dashboard && npm run build` — builds successfully

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 05:56:23 -08:00
Alex Cheema
2063278906 feat: add custom HuggingFace model support (#1368)
## Motivation

Users should be able to run any HuggingFace model, not just the ones we
ship TOML cards for. Continues the aim of #1191 with a minimal
implementation on top of the current TOML model card system.

Custom cards are saved to `~/.exo/custom_model_cards/` rather than the
bundled `resources/inference_model_cards/` because `RESOURCES_DIR` is
read-only in PyInstaller bundles (`sys._MEIPASS`). This also fixes
`fetch_from_hf` which was saving cards to the wrong path (`resources/`
root instead of `resources/inference_model_cards/`).

## Changes

- Add `EXO_CUSTOM_MODEL_CARDS_DIR` constant
(`~/.exo/custom_model_cards/`)
- Update `model_cards.py`: add custom dir to search path, fix
`save_to_custom_dir`, add `delete_custom_card`/`is_custom_card`
- Add `POST /models/add` and `DELETE /models/custom/{model_id}` API
endpoints
- Add `is_custom` field to `ModelListModel` API response
- Dashboard: add custom model input form in dropdown, delete button for
custom models, show actual API errors, auto-select newly added model

## Why It Works

Two separate directories for model cards: the bundled read-only
`resources/inference_model_cards/` for built-in cards, and user-writable
`~/.exo/custom_model_cards/` for custom cards. Both are scanned when
listing models. This works in all environments including PyInstaller
bundles where `RESOURCES_DIR` points to `sys._MEIPASS`.

## Test Plan

### Manual Testing
- Add a custom model via the dropdown (e.g.
`mlx-community/Llama-3.2-1B-Instruct-4bit`)
- Verify it appears in the model list with the delete (x) button
- Delete it and verify it disappears
- Try adding an invalid model ID and verify the actual error is shown

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `uv run pytest src/` — passes
- `cd dashboard && npm run build` — builds

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 05:06:15 -08:00
rltakashige
a0f4f36355 Reduce reliance on internet (#1363)
## Motivation

Offline users currently have to wait for every retry to fail before
being able to launch a model.
For users that restart clusters often or share API keys between devices,
we also spam HuggingFace with downloads every 5 minutes.
These issues are caused by _emit_existing_download_progress being
inefficient.

## Changes

- Only query HuggingFace once while EXO is running (assumption being
that a change should only be reflected on a new EXO session)
- Only query HuggingFace when there is an internet connection (polling
connectivity every 10 seconds)
- Request download progress if we switch from no connectivity ->
connected to reduce the wait.
- Reduce download progress sleep as it's no longer expensive (queries
cache most of the time).
- Reduce retries as 30 is way too many.

## Test Plan

### Manual Testing
Manually tested the behaviour.

### Automated Testing
None, should I add any? We do have some tests for this folder, but they
are probably not too helpful.
2026-02-03 20:03:29 +00:00
Alex Cheema
acb97127bf Normalize TextGenerationTaskParams.input to list[InputMessage] (#1360)
## Motivation

With the addition of the Responses API, we introduced `str |
list[InputMessage]` as the type for `TextGenerationTaskParams.input`
since the Responses API supports sending input as a plain string. But
there was no reason to leak that flexibility past the API adapter
boundary — it just meant every downstream consumer had to do `if
isinstance(messages, str):` checks, adding complexity for no benefit.

## Changes

- Changed `TextGenerationTaskParams.input` from `str |
list[InputMessage]` to `list[InputMessage]`
- Each API adapter (Chat Completions, Claude Messages, Responses) now
normalizes to `list[InputMessage]` at the boundary
- Removed `isinstance(task_params.input, str)` branches in
`utils_mlx.py` and `runner.py`
- Wrapped string inputs in `[InputMessage(role="user", content=...)]` in
the warmup path and all test files

## Why It Works

The API adapters are the only place where we deal with raw user input
formats. By normalizing there, all downstream code (worker, runner, MLX
engine) can just assume `list[InputMessage]` and skip the type-checking
branches. The type system (`basedpyright`) catches any missed call sites
at compile time.

## Test Plan

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — applied
- `uv run pytest` — 174 passed, 1 skipped

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 06:01:56 -08:00
Evan Quiney
d90605f198 migrate model cards to .toml files (#1354) 2026-02-03 12:32:06 +00:00
94 changed files with 3806 additions and 982 deletions

View File

@@ -142,4 +142,4 @@ jobs:
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

1
.gitignore vendored
View File

@@ -31,3 +31,4 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp

View File

@@ -108,6 +108,7 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | set[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
@@ -117,7 +118,7 @@ class TokenizerWrapper:
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
eos_token_ids: list[int] | set[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,

View File

@@ -0,0 +1,73 @@
<script lang="ts">
type FamilyLogoProps = {
family: string;
class?: string;
};
let { family, class: className = "" }: FamilyLogoProps = $props();
</script>
{#if family === "favorites"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else if family === "llama" || family === "meta"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M6.915 4.03c-1.968 0-3.683 1.28-4.871 3.113C.704 9.208 0 11.883 0 14.449c0 .706.07 1.369.21 1.973a6.624 6.624 0 0 0 .265.86 5.297 5.297 0 0 0 .371.761c.696 1.159 1.818 1.927 3.593 1.927 1.497 0 2.633-.671 3.965-2.444.76-1.012 1.144-1.626 2.663-4.32l.756-1.339.186-.325c.061.1.121.196.183.3l2.152 3.595c.724 1.21 1.665 2.556 2.47 3.314 1.046.987 1.992 1.22 3.06 1.22 1.075 0 1.876-.355 2.455-.843a3.743 3.743 0 0 0 .81-.973c.542-.939.861-2.127.861-3.745 0-2.72-.681-5.357-2.084-7.45-1.282-1.912-2.957-2.93-4.716-2.93-1.047 0-2.088.467-3.053 1.308-.652.57-1.257 1.29-1.82 2.05-.69-.875-1.335-1.547-1.958-2.056-1.182-.966-2.315-1.303-3.454-1.303zm10.16 2.053c1.147 0 2.188.758 2.992 1.999 1.132 1.748 1.647 4.195 1.647 6.4 0 1.548-.368 2.9-1.839 2.9-.58 0-1.027-.23-1.664-1.004-.496-.601-1.343-1.878-2.832-4.358l-.617-1.028a44.908 44.908 0 0 0-1.255-1.98c.07-.109.141-.224.211-.327 1.12-1.667 2.118-2.602 3.358-2.602zm-10.201.553c1.265 0 2.058.791 2.675 1.446.307.327.737.871 1.234 1.579l-1.02 1.566c-.757 1.163-1.882 3.017-2.837 4.338-1.191 1.649-1.81 1.817-2.486 1.817-.524 0-1.038-.237-1.383-.794-.263-.426-.464-1.13-.464-2.046 0-2.221.63-4.535 1.66-6.088.454-.687.964-1.226 1.533-1.533a2.264 2.264 0 0 1 1.088-.285z"
/>
</svg>
{:else if family === "qwen"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.604 1.34c.393.69.784 1.382 1.174 2.075a.18.18 0 00.157.091h5.552c.174 0 .322.11.446.327l1.454 2.57c.19.337.24.478.024.837-.26.43-.513.864-.76 1.3l-.367.658c-.106.196-.223.28-.04.512l2.652 4.637c.172.301.111.494-.043.77-.437.785-.882 1.564-1.335 2.34-.159.272-.352.375-.68.37-.777-.016-1.552-.01-2.327.016a.099.099 0 00-.081.05 575.097 575.097 0 01-2.705 4.74c-.169.293-.38.363-.725.364-.997.003-2.002.004-3.017.002a.537.537 0 01-.465-.271l-1.335-2.323a.09.09 0 00-.083-.049H4.982c-.285.03-.553-.001-.805-.092l-1.603-2.77a.543.543 0 01-.002-.54l1.207-2.12a.198.198 0 000-.197 550.951 550.951 0 01-1.875-3.272l-.79-1.395c-.16-.31-.173-.496.095-.965.465-.813.927-1.625 1.387-2.436.132-.234.304-.334.584-.335a338.3 338.3 0 012.589-.001.124.124 0 00.107-.063l2.806-4.895a.488.488 0 01.422-.246c.524-.001 1.053 0 1.583-.006L11.704 1c.341-.003.724.032.9.34zm-3.432.403a.06.06 0 00-.052.03L6.254 6.788a.157.157 0 01-.135.078H3.253c-.056 0-.07.025-.041.074l5.81 10.156c.025.042.013.062-.034.063l-2.795.015a.218.218 0 00-.2.116l-1.32 2.31c-.044.078-.021.118.068.118l5.716.008c.046 0 .08.02.104.061l1.403 2.454c.046.081.092.082.139 0l5.006-8.76.783-1.382a.055.055 0 01.096 0l1.424 2.53a.122.122 0 00.107.062l2.763-.02a.04.04 0 00.035-.02.041.041 0 000-.04l-2.9-5.086a.108.108 0 010-.113l.293-.507 1.12-1.977c.024-.041.012-.062-.035-.062H9.2c-.059 0-.073-.026-.043-.077l1.434-2.505a.107.107 0 000-.114L9.225 1.774a.06.06 0 00-.053-.031zm6.29 8.02c.046 0 .058.02.034.06l-.832 1.465-2.613 4.585a.056.056 0 01-.05.029.058.058 0 01-.05-.029L8.498 9.841c-.02-.034-.01-.052.028-.054l.216-.012 6.722-.012z"
/>
</svg>
{:else if family === "deepseek"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M23.748 4.482c-.254-.124-.364.113-.512.234-.051.039-.094.09-.137.136-.372.397-.806.657-1.373.626-.829-.046-1.537.214-2.163.848-.133-.782-.575-1.248-1.247-1.548-.352-.156-.708-.311-.955-.65-.172-.241-.219-.51-.305-.774-.055-.16-.11-.323-.293-.35-.2-.031-.278.136-.356.276-.313.572-.434 1.202-.422 1.84.027 1.436.633 2.58 1.838 3.393.137.093.172.187.129.323-.082.28-.18.552-.266.833-.055.179-.137.217-.329.14a5.526 5.526 0 01-1.736-1.18c-.857-.828-1.631-1.742-2.597-2.458a11.365 11.365 0 00-.689-.471c-.985-.957.13-1.743.388-1.836.27-.098.093-.432-.779-.428-.872.004-1.67.295-2.687.684a3.055 3.055 0 01-.465.137 9.597 9.597 0 00-2.883-.102c-1.885.21-3.39 1.102-4.497 2.623C.082 8.606-.231 10.684.152 12.85c.403 2.284 1.569 4.175 3.36 5.653 1.858 1.533 3.997 2.284 6.438 2.14 1.482-.085 3.133-.284 4.994-1.86.47.234.962.327 1.78.397.63.059 1.236-.03 1.705-.128.735-.156.684-.837.419-.961-2.155-1.004-1.682-.595-2.113-.926 1.096-1.296 2.746-2.642 3.392-7.003.05-.347.007-.565 0-.845-.004-.17.035-.237.23-.256a4.173 4.173 0 001.545-.475c1.396-.763 1.96-2.015 2.093-3.517.02-.23-.004-.467-.247-.588zM11.581 18c-2.089-1.642-3.102-2.183-3.52-2.16-.392.024-.321.471-.235.763.09.288.207.486.371.739.114.167.192.416-.113.603-.673.416-1.842-.14-1.897-.167-1.361-.802-2.5-1.86-3.301-3.307-.774-1.393-1.224-2.887-1.298-4.482-.02-.386.093-.522.477-.592a4.696 4.696 0 011.529-.039c2.132.312 3.946 1.265 5.468 2.774.868.86 1.525 1.887 2.202 2.891.72 1.066 1.494 2.082 2.48 2.914.348.292.625.514.891.677-.802.09-2.14.11-3.054-.614zm1-6.44a.306.306 0 01.415-.287.302.302 0 01.2.288.306.306 0 01-.31.307.303.303 0 01-.304-.308zm3.11 1.596c-.2.081-.399.151-.59.16a1.245 1.245 0 01-.798-.254c-.274-.23-.47-.358-.552-.758a1.73 1.73 0 01.016-.588c.07-.327-.008-.537-.239-.727-.187-.156-.426-.199-.688-.199a.559.559 0 01-.254-.078c-.11-.054-.2-.19-.114-.358.028-.054.16-.186.192-.21.356-.202.767-.136 1.146.016.352.144.618.408 1.001.782.391.451.462.576.685.914.176.265.336.537.445.848.067.195-.019.354-.25.452z"
/>
</svg>
{:else if family === "openai" || family === "gpt-oss"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M22.2819 9.8211a5.9847 5.9847 0 0 0-.5157-4.9108 6.0462 6.0462 0 0 0-6.5098-2.9A6.0651 6.0651 0 0 0 4.9807 4.1818a5.9847 5.9847 0 0 0-3.9977 2.9 6.0462 6.0462 0 0 0 .7427 7.0966 5.98 5.98 0 0 0 .511 4.9107 6.051 6.051 0 0 0 6.5146 2.9001A5.9847 5.9847 0 0 0 13.2599 24a6.0557 6.0557 0 0 0 5.7718-4.2058 5.9894 5.9894 0 0 0 3.9977-2.9001 6.0557 6.0557 0 0 0-.7475-7.0729zm-9.022 12.6081a4.4755 4.4755 0 0 1-2.8764-1.0408l.1419-.0804 4.7783-2.7582a.7948.7948 0 0 0 .3927-.6813v-6.7369l2.02 1.1686a.071.071 0 0 1 .038.052v5.5826a4.504 4.504 0 0 1-4.4945 4.4944zm-9.6607-4.1254a4.4708 4.4708 0 0 1-.5346-3.0137l.142.0852 4.783 2.7582a.7712.7712 0 0 0 .7806 0l5.8428-3.3685v2.3324a.0804.0804 0 0 1-.0332.0615L9.74 19.9502a4.4992 4.4992 0 0 1-6.1408-1.6464zM2.3408 7.8956a4.485 4.485 0 0 1 2.3655-1.9728V11.6a.7664.7664 0 0 0 .3879.6765l5.8144 3.3543-2.0201 1.1685a.0757.0757 0 0 1-.071 0l-4.8303-2.7865A4.504 4.504 0 0 1 2.3408 7.872zm16.5963 3.8558L13.1038 8.364 15.1192 7.2a.0757.0757 0 0 1 .071 0l4.8303 2.7913a4.4944 4.4944 0 0 1-.6765 8.1042v-5.6772a.79.79 0 0 0-.407-.667zm2.0107-3.0231l-.142-.0852-4.7735-2.7818a.7759.7759 0 0 0-.7854 0L9.409 9.2297V6.8974a.0662.0662 0 0 1 .0284-.0615l4.8303-2.7866a4.4992 4.4992 0 0 1 6.6802 4.66zM8.3065 12.863l-2.02-1.1638a.0804.0804 0 0 1-.038-.0567V6.0742a4.4992 4.4992 0 0 1 7.3757-3.4537l-.142.0805L8.704 5.459a.7948.7948 0 0 0-.3927.6813zm1.0976-2.3654l2.602-1.4998 2.6069 1.4998v2.9994l-2.5974 1.4997-2.6067-1.4997Z"
/>
</svg>
{:else if family === "glm"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M11.991 23.503a.24.24 0 00-.244.248.24.24 0 00.244.249.24.24 0 00.245-.249.24.24 0 00-.22-.247l-.025-.001zM9.671 5.365a1.697 1.697 0 011.099 2.132l-.071.172-.016.04-.018.054c-.07.16-.104.32-.104.498-.035.71.47 1.279 1.186 1.314h.366c1.309.053 2.338 1.173 2.286 2.523-.052 1.332-1.152 2.38-2.478 2.327h-.174c-.715.018-1.274.64-1.239 1.368 0 .124.018.23.053.337.209.373.54.658.96.8.75.23 1.517-.125 1.9-.782l.018-.035c.402-.64 1.17-.96 1.92-.711.854.284 1.378 1.226 1.099 2.167a1.661 1.661 0 01-2.077 1.102 1.711 1.711 0 01-.907-.711l-.017-.035c-.2-.323-.463-.58-.851-.711l-.056-.018a1.646 1.646 0 00-1.954.746 1.66 1.66 0 01-1.065.764 1.677 1.677 0 01-1.989-1.279c-.209-.906.332-1.83 1.257-2.043a1.51 1.51 0 01.296-.035h.018c.68-.071 1.151-.622 1.116-1.333a1.307 1.307 0 00-.227-.693 2.515 2.515 0 01-.366-1.403 2.39 2.39 0 01.366-1.208c.14-.195.21-.444.227-.693.018-.71-.506-1.261-1.186-1.332l-.07-.018a1.43 1.43 0 01-.299-.07l-.05-.019a1.7 1.7 0 01-1.047-2.114 1.68 1.68 0 012.094-1.101zm-5.575 10.11c.26-.264.639-.367.994-.27.355.096.633.379.728.74.095.362-.007.748-.267 1.013-.402.41-1.053.41-1.455 0a1.062 1.062 0 010-1.482zm14.845-.294c.359-.09.738.024.992.297.254.274.344.665.237 1.025-.107.36-.396.634-.756.718-.551.128-1.1-.22-1.23-.781a1.05 1.05 0 01.757-1.26zm-.064-4.39c.314.32.49.753.49 1.206 0 .452-.176.886-.49 1.206-.315.32-.74.5-1.185.5-.444 0-.87-.18-1.184-.5a1.727 1.727 0 010-2.412 1.654 1.654 0 012.369 0zm-11.243.163c.364.484.447 1.128.218 1.691a1.665 1.665 0 01-2.188.923c-.855-.36-1.26-1.358-.907-2.228a1.68 1.68 0 011.33-1.038c.593-.08 1.183.169 1.547.652zm11.545-4.221c.368 0 .708.2.892.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.892.524c-.568 0-1.03-.47-1.03-1.048 0-.579.462-1.048 1.03-1.048zm-14.358 0c.368 0 .707.2.891.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.891.524c-.569 0-1.03-.47-1.03-1.048 0-.579.461-1.048 1.03-1.048zm10.031-1.475c.925 0 1.675.764 1.675 1.706s-.75 1.705-1.675 1.705-1.674-.763-1.674-1.705c0-.942.75-1.706 1.674-1.706zm-2.626-.684c.362-.082.653-.356.761-.718a1.062 1.062 0 00-.238-1.028 1.017 1.017 0 00-.996-.294c-.547.14-.881.7-.752 1.257.13.558.675.907 1.225.783zm0 16.876c.359-.087.644-.36.75-.72a1.062 1.062 0 00-.237-1.019 1.018 1.018 0 00-.985-.301 1.037 1.037 0 00-.762.717c-.108.361-.017.754.239 1.028.245.263.606.377.953.305l.043-.01zM17.19 3.5a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64a.631.631 0 00-.628.64c0 .355.28.64.628.64zm-10.38 0a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64a.631.631 0 00-.628.64c0 .355.279.64.628.64zm-5.182 7.852a.631.631 0 00-.628.64c0 .354.28.639.628.639a.63.63 0 00.627-.606l.001-.034a.62.62 0 00-.628-.64zm5.182 9.13a.631.631 0 00-.628.64c0 .355.279.64.628.64a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm10.38.018a.631.631 0 00-.628.64c0 .355.28.64.628.64a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64zm5.182-9.148a.631.631 0 00-.628.64c0 .354.279.639.628.639a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm-.384-4.992a.24.24 0 00.244-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249c0 .142.122.249.244.249zM11.991.497a.24.24 0 00.245-.248A.24.24 0 0011.99 0a.24.24 0 00-.244.249c0 .133.108.236.223.247l.021.001zM2.011 6.36a.24.24 0 00.245-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249.24.24 0 00.244.249zm0 11.263a.24.24 0 00-.243.248.24.24 0 00.244.249.24.24 0 00.244-.249.252.252 0 00-.244-.248zm19.995-.018a.24.24 0 00-.245.248.24.24 0 00.245.25.24.24 0 00.244-.25.252.252 0 00-.244-.248z"
/>
</svg>
{:else if family === "minimax"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M16.278 2c1.156 0 2.093.927 2.093 2.07v12.501a.74.74 0 00.744.709.74.74 0 00.743-.709V9.099a2.06 2.06 0 012.071-2.049A2.06 2.06 0 0124 9.1v6.561a.649.649 0 01-.652.645.649.649 0 01-.653-.645V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v7.472a2.037 2.037 0 01-2.048 2.026 2.037 2.037 0 01-2.048-2.026v-12.5a.785.785 0 00-.788-.753.785.785 0 00-.789.752l-.001 15.904A2.037 2.037 0 0113.441 22a2.037 2.037 0 01-2.048-2.026V18.04c0-.356.292-.645.652-.645.36 0 .652.289.652.645v1.934c0 .263.142.506.372.638.23.131.514.131.744 0a.734.734 0 00.372-.638V4.07c0-1.143.937-2.07 2.093-2.07zm-5.674 0c1.156 0 2.093.927 2.093 2.07v11.523a.648.648 0 01-.652.645.648.648 0 01-.652-.645V4.07a.785.785 0 00-.789-.78.785.785 0 00-.789.78v14.013a2.06 2.06 0 01-2.07 2.048 2.06 2.06 0 01-2.071-2.048V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v3.8a2.06 2.06 0 01-2.071 2.049A2.06 2.06 0 010 12.9v-1.378c0-.357.292-.646.652-.646.36 0 .653.29.653.646V12.9c0 .418.343.757.766.757s.766-.339.766-.757V9.099a2.06 2.06 0 012.07-2.048 2.06 2.06 0 012.071 2.048v8.984c0 .419.343.758.767.758.423 0 .766-.339.766-.758V4.07c0-1.143.937-2.07 2.093-2.07z"
/>
</svg>
{:else if family === "kimi"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19.738 5.776c.163-.209.306-.4.457-.585.07-.087.064-.153-.004-.244-.655-.861-.717-1.817-.34-2.787.283-.73.909-1.072 1.674-1.145.477-.045.945.004 1.379.236.57.305.902.77 1.01 1.412.086.512.07 1.012-.075 1.508-.257.878-.888 1.333-1.753 1.448-.718.096-1.446.108-2.17.157-.056.004-.113 0-.178 0z"
/>
<path
d="M17.962 1.844h-4.326l-3.425 7.81H5.369V1.878H1.5V22h3.87v-8.477h6.824a3.025 3.025 0 002.743-1.75V22h3.87v-8.477a3.87 3.87 0 00-3.588-3.86v-.01h-2.125a3.94 3.94 0 002.323-2.12l2.545-5.689z"
/>
</svg>
{:else if family === "huggingface"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.025 1.13c-5.77 0-10.449 4.647-10.449 10.378 0 1.112.178 2.181.503 3.185.064-.222.203-.444.416-.577a.96.96 0 0 1 .524-.15c.293 0 .584.124.84.284.278.173.48.408.71.694.226.282.458.611.684.951v-.014c.017-.324.106-.622.264-.874s.403-.487.762-.543c.3-.047.596.06.787.203s.31.313.4.467c.15.257.212.468.233.542.01.026.653 1.552 1.657 2.54.616.605 1.01 1.223 1.082 1.912.055.537-.096 1.059-.38 1.572.637.121 1.294.187 1.967.187.657 0 1.298-.063 1.921-.178-.287-.517-.44-1.041-.384-1.581.07-.69.465-1.307 1.081-1.913 1.004-.987 1.647-2.513 1.657-2.539.021-.074.083-.285.233-.542.09-.154.208-.323.4-.467a1.08 1.08 0 0 1 .787-.203c.359.056.604.29.762.543s.247.55.265.874v.015c.225-.34.457-.67.683-.952.23-.286.432-.52.71-.694.257-.16.547-.284.84-.285a.97.97 0 0 1 .524.151c.228.143.373.388.43.625l.006.04a10.3 10.3 0 0 0 .534-3.273c0-5.731-4.678-10.378-10.449-10.378M8.327 6.583a1.5 1.5 0 0 1 .713.174 1.487 1.487 0 0 1 .617 2.013c-.183.343-.762-.214-1.102-.094-.38.134-.532.914-.917.71a1.487 1.487 0 0 1 .69-2.803m7.486 0a1.487 1.487 0 0 1 .689 2.803c-.385.204-.536-.576-.916-.71-.34-.12-.92.437-1.103.094a1.487 1.487 0 0 1 .617-2.013 1.5 1.5 0 0 1 .713-.174m-10.68 1.55a.96.96 0 1 1 0 1.921.96.96 0 0 1 0-1.92m13.838 0a.96.96 0 1 1 0 1.92.96.96 0 0 1 0-1.92M8.489 11.458c.588.01 1.965 1.157 3.572 1.164 1.607-.007 2.984-1.155 3.572-1.164.196-.003.305.12.305.454 0 .886-.424 2.328-1.563 3.202-.22-.756-1.396-1.366-1.63-1.32q-.011.001-.02.006l-.044.026-.01.008-.03.024q-.018.017-.035.036l-.032.04a1 1 0 0 0-.058.09l-.014.025q-.049.088-.11.19a1 1 0 0 1-.083.116 1.2 1.2 0 0 1-.173.18q-.035.029-.075.058a1.3 1.3 0 0 1-.251-.243 1 1 0 0 1-.076-.107c-.124-.193-.177-.363-.337-.444-.034-.016-.104-.008-.2.022q-.094.03-.216.087-.06.028-.125.063l-.13.074q-.067.04-.136.086a3 3 0 0 0-.135.096 3 3 0 0 0-.26.219 2 2 0 0 0-.12.121 2 2 0 0 0-.106.128l-.002.002a2 2 0 0 0-.09.132l-.001.001a1.2 1.2 0 0 0-.105.212q-.013.036-.024.073c-1.139-.875-1.563-2.317-1.563-3.203 0-.334.109-.457.305-.454m.836 10.354c.824-1.19.766-2.082-.365-3.194-1.13-1.112-1.789-2.738-1.789-2.738s-.246-.945-.806-.858-.97 1.499.202 2.362c1.173.864-.233 1.45-.685.64-.45-.812-1.683-2.896-2.322-3.295s-1.089-.175-.938.647 2.822 2.813 2.562 3.244-1.176-.506-1.176-.506-2.866-2.567-3.49-1.898.473 1.23 2.037 2.16c1.564.932 1.686 1.178 1.464 1.53s-3.675-2.511-4-1.297c-.323 1.214 3.524 1.567 3.287 2.405-.238.839-2.71-1.587-3.216-.642-.506.946 3.49 2.056 3.522 2.064 1.29.33 4.568 1.028 5.713-.624m5.349 0c-.824-1.19-.766-2.082.365-3.194 1.13-1.112 1.789-2.738 1.789-2.738s.246-.945.806-.858.97 1.499-.202 2.362c-1.173.864.233 1.45.685.64.451-.812 1.683-2.896 2.322-3.295s1.089-.175.938.647-2.822 2.813-2.562 3.244 1.176-.506 1.176-.506 2.866-2.567 3.49-1.898-.473 1.23-2.037 2.16c-1.564.932-1.686 1.178-1.464 1.53s3.675-2.511 4-1.297c.323 1.214-3.524 1.567-3.287 2.405.238.839 2.71-1.587 3.216-.642.506.946-3.49 2.056-3.522 2.064-1.29.33-4.568 1.028-5.713-.624"
/>
</svg>
{:else}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
{/if}

View File

@@ -0,0 +1,142 @@
<script lang="ts">
import FamilyLogos from "./FamilyLogos.svelte";
type FamilySidebarProps = {
families: string[];
selectedFamily: string | null;
hasFavorites: boolean;
onSelect: (family: string | null) => void;
};
let { families, selectedFamily, hasFavorites, onSelect }: FamilySidebarProps =
$props();
// Family display names
const familyNames: Record<string, string> = {
favorites: "Favorites",
huggingface: "Hub",
llama: "Meta",
qwen: "Qwen",
deepseek: "DeepSeek",
"gpt-oss": "OpenAI",
glm: "GLM",
minimax: "MiniMax",
kimi: "Kimi",
};
function getFamilyName(family: string): string {
return (
familyNames[family] || family.charAt(0).toUpperCase() + family.slice(1)
);
}
</script>
<div
class="flex flex-col gap-1 py-2 px-1 border-r border-exo-yellow/10 bg-exo-medium-gray/30 min-w-[64px]"
>
<!-- All models (no filter) -->
<button
type="button"
onclick={() => onSelect(null)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
null
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="All models"
>
<svg
class="w-5 h-5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/50 group-hover:text-white/70'}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M4 8h4V4H4v4zm6 12h4v-4h-4v4zm-6 0h4v-4H4v4zm0-6h4v-4H4v4zm6 0h4v-4h-4v4zm6-10v4h4V4h-4zm-6 4h4V4h-4v4zm6 6h4v-4h-4v4zm0 6h4v-4h-4v4z"
/>
</svg>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}">All</span
>
</button>
<!-- Favorites (only show if has favorites) -->
{#if hasFavorites}
<button
type="button"
onclick={() => onSelect("favorites")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'favorites'
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Show favorited models"
>
<FamilyLogos
family="favorites"
class={selectedFamily === "favorites"
? "text-amber-400"
: "text-white/50 group-hover:text-amber-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'favorites'
? 'text-amber-400'
: 'text-white/40 group-hover:text-white/60'}">Faves</span
>
</button>
{/if}
<!-- HuggingFace Hub -->
<button
type="button"
onclick={() => onSelect("huggingface")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'huggingface'
? 'bg-orange-500/20 border-l-2 border-orange-400'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Browse and add models from Hugging Face"
>
<FamilyLogos
family="huggingface"
class={selectedFamily === "huggingface"
? "text-orange-400"
: "text-white/50 group-hover:text-orange-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'huggingface'
? 'text-orange-400'
: 'text-white/40 group-hover:text-white/60'}">Hub</span
>
</button>
<div class="h-px bg-exo-yellow/10 my-1"></div>
<!-- Model families -->
{#each families as family}
<button
type="button"
onclick={() => onSelect(family)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
family
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title={getFamilyName(family)}
>
<FamilyLogos
{family}
class={selectedFamily === family
? "text-exo-yellow"
: "text-white/50 group-hover:text-white/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 truncate max-w-full {selectedFamily ===
family
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}"
>
{getFamilyName(family)}
</span>
</button>
{/each}
</div>

View File

@@ -0,0 +1,127 @@
<script lang="ts">
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type HuggingFaceResultItemProps = {
model: HuggingFaceModel;
isAdded: boolean;
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
};
let {
model,
isAdded,
isAdding,
onAdd,
onSelect,
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {
return `${(num / 1000).toFixed(1)}k`;
}
return num.toString();
}
// Extract model name from full ID (e.g., "mlx-community/Llama-3.2-1B" -> "Llama-3.2-1B")
const modelName = $derived(model.id.split("/").pop() || model.id);
</script>
<div
class="flex items-center justify-between gap-3 px-3 py-2.5 hover:bg-white/5 transition-colors border-b border-white/5 last:border-b-0"
>
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
>Added</span
>
{/if}
</div>
<div class="flex items-center gap-3 mt-0.5 text-xs text-white/40">
<span class="truncate">{model.author}</span>
<span
class="flex items-center gap-1 shrink-0"
title="Downloads in the last 30 days"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"
/>
</svg>
{formatNumber(model.downloads)}
</span>
<span
class="flex items-center gap-1 shrink-0"
title="Community likes on Hugging Face"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4.318 6.318a4.5 4.5 0 000 6.364L12 20.364l7.682-7.682a4.5 4.5 0 00-6.364-6.364L12 7.636l-1.318-1.318a4.5 4.5 0 00-6.364 0z"
/>
</svg>
{formatNumber(model.likes)}
</span>
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
{#if isAdded}
<button
type="button"
onclick={onSelect}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-yellow/10 text-exo-yellow border border-exo-yellow/30 hover:bg-exo-yellow/20 transition-colors rounded cursor-pointer"
>
Select
</button>
{:else}
<button
type="button"
onclick={onAdd}
disabled={isAdding}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed"
>
{#if isAdding}
<span class="flex items-center gap-1.5">
<span
class="w-3 h-3 border-2 border-orange-400 border-t-transparent rounded-full animate-spin"
></span>
Adding...
</span>
{:else}
+ Add
{/if}
</button>
{/if}
</div>
</div>

View File

@@ -0,0 +1,182 @@
<script lang="ts">
import { fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
}
type ModelFilterPopoverProps = {
filters: FilterState;
onChange: (filters: FilterState) => void;
onClear: () => void;
onClose: () => void;
};
let { filters, onChange, onClear, onClose }: ModelFilterPopoverProps =
$props();
// Available capabilities
const availableCapabilities = [
{ id: "text", label: "Text" },
{ id: "thinking", label: "Thinking" },
{ id: "code", label: "Code" },
{ id: "vision", label: "Vision" },
];
// Size ranges
const sizeRanges = [
{ label: "< 10GB", min: 0, max: 10 },
{ label: "10-50GB", min: 10, max: 50 },
{ label: "50-200GB", min: 50, max: 200 },
{ label: "> 200GB", min: 200, max: 10000 },
];
function toggleCapability(cap: string) {
const next = filters.capabilities.includes(cap)
? filters.capabilities.filter((c) => c !== cap)
: [...filters.capabilities, cap];
onChange({ ...filters, capabilities: next });
}
function selectSizeRange(range: { min: number; max: number } | null) {
// Toggle off if same range is clicked
if (
filters.sizeRange &&
range &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max
) {
onChange({ ...filters, sizeRange: null });
} else {
onChange({ ...filters, sizeRange: range });
}
}
function handleClickOutside(e: MouseEvent) {
const target = e.target as HTMLElement;
if (
!target.closest(".filter-popover") &&
!target.closest(".filter-toggle")
) {
onClose();
}
}
</script>
<svelte:window onclick={handleClickOutside} />
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="filter-popover absolute right-0 top-full mt-2 w-64 bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-xl z-10"
transition:fly={{ y: -10, duration: 200, easing: cubicOut }}
onclick={(e) => e.stopPropagation()}
role="dialog"
aria-label="Filter options"
>
<div class="p-3 space-y-4">
<!-- Capabilities -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Capabilities</h4>
<div class="flex flex-wrap gap-1.5">
{#each availableCapabilities as cap}
{@const isSelected = filters.capabilities.includes(cap.id)}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => toggleCapability(cap.id)}
>
{#if cap.id === "text"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "thinking"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "code"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "vision"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/><circle cx="12" cy="12" r="3" /></svg
>
{/if}
<span class="ml-1">{cap.label}</span>
</button>
{/each}
</div>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
<div class="flex flex-wrap gap-1.5">
{#each sizeRanges as range}
{@const isSelected =
filters.sizeRange &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => selectSizeRange(range)}
>
{range.label}
</button>
{/each}
</div>
</div>
<!-- Clear button -->
<button
type="button"
class="w-full py-1.5 text-xs font-mono text-white/50 hover:text-white/70 hover:bg-white/5 rounded transition-colors"
onclick={onClear}
>
Clear all filters
</button>
</div>
</div>

View File

@@ -0,0 +1,357 @@
<script lang="ts">
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
isFavorite: boolean;
selectedModelId: string | null;
canModelFit: (id: string) => boolean;
onToggleExpand: () => void;
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatus?: DownloadAvailability;
};
let {
group,
isExpanded,
isFavorite,
selectedModelId,
canModelFit,
onToggleExpand,
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatus,
}: ModelPickerGroupProps = $props();
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
if (mb >= 1024) {
return `${(mb / 1024).toFixed(0)}GB`;
}
return `${mb}MB`;
}
// Check if any variant can fit
const anyVariantFits = $derived(
group.variants.some((v) => canModelFit(v.id)),
);
// Check if this group's model is currently selected (for single-variant groups)
const isMainSelected = $derived(
!group.hasMultipleVariants &&
group.variants.some((v) => v.id === selectedModelId),
);
</script>
<div
class="border-b border-white/5 last:border-b-0 {!anyVariantFits
? 'opacity-50'
: ''}"
>
<!-- Main row -->
<div
class="flex items-center gap-2 px-3 py-2.5 transition-colors {anyVariantFits
? 'hover:bg-white/5 cursor-pointer'
: 'cursor-not-allowed'} {isMainSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
onclick={() => {
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}}
role="button"
tabindex="0"
onkeydown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}
}}
>
<!-- Expand/collapse chevron (for groups with variants) -->
{#if group.hasMultipleVariants}
<svg
class="w-4 h-4 text-white/40 transition-transform duration-200 flex-shrink-0 {isExpanded
? 'rotate-90'
: ''}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M8.59 16.59L13.17 12 8.59 7.41 10 6l6 6-6 6-1.41-1.41z" />
</svg>
{:else}
<div class="w-4 flex-shrink-0"></div>
{/if}
<!-- Model name -->
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="font-mono text-sm text-white truncate">
{group.name}
</span>
<!-- Capability icons -->
{#each group.capabilities.filter((c) => c !== "text") as cap}
{#if cap === "thinking"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports Thinking"
>
<path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "code"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports code generation"
>
<path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "vision"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports image input"
>
<path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/>
<circle cx="12" cy="12" r="3" />
</svg>
{:else if cap === "image_gen"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports image generation"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<path d="M21 15l-5-5L5 21" />
</svg>
{/if}
{/each}
</div>
</div>
<!-- Size indicator (smallest variant) -->
{#if !group.hasMultipleVariants && group.smallestVariant?.storage_size_megabytes}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{formatSize(group.smallestVariant.storage_size_megabytes)}
</span>
{/if}
<!-- Variant count -->
{#if group.hasMultipleVariants}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants
</span>
{/if}
<!-- Download availability indicator -->
{#if downloadStatus && downloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={downloadStatus.available
? `Ready — downloaded on ${downloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${downloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4 {downloadStatus.available
? 'text-green-400'
: 'text-green-400/40'}"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path d="M22 11.08V12a10 10 0 1 1-5.93-9.14" />
<polyline points="22 4 12 14.01 9 11.01" />
</svg>
</span>
{/if}
<!-- Check mark if selected (single-variant) -->
{#if isMainSelected}
<svg
class="w-4 h-4 text-exo-yellow flex-shrink-0"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z" />
</svg>
{/if}
<!-- Favorite star -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onToggleFavorite(group.id);
}}
title={isFavorite ? "Remove from favorites" : "Add to favorites"}
>
{#if isFavorite}
<svg
class="w-4 h-4 text-amber-400"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else}
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{/if}
</button>
<!-- Info button -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onShowInfo(group);
}}
title="View model details"
>
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z"
/>
</svg>
</button>
</div>
<!-- Expanded variants -->
{#if isExpanded && group.hasMultipleVariants}
<div class="bg-black/20 border-t border-white/5">
{#each group.variants as variant}
{@const modelCanFit = canModelFit(variant.id)}
{@const isSelected = selectedModelId === variant.id}
<button
type="button"
class="w-full flex items-center gap-3 px-3 py-2 pl-10 hover:bg-white/5 transition-colors text-left {!modelCanFit
? 'opacity-50 cursor-not-allowed'
: 'cursor-pointer'} {isSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
disabled={!modelCanFit}
onclick={() => {
if (modelCanFit) {
onSelectModel(variant.id);
}
}}
>
<!-- Quantization badge -->
<span
class="text-xs font-mono px-1.5 py-0.5 rounded bg-white/10 text-white/70 flex-shrink-0"
>
{variant.quantization || "default"}
</span>
<!-- Size -->
<span class="text-xs font-mono text-white/40 flex-1">
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Check mark if selected -->
{#if isSelected}
<svg
class="w-4 h-4 text-exo-yellow"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z"
/>
</svg>
{/if}
</button>
{/each}
</div>
{/if}
</div>

View File

@@ -0,0 +1,842 @@
<script lang="ts">
import { fade, fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
import FamilySidebar from "./FamilySidebar.svelte";
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
tasks?: string[];
hugging_face_id?: string;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
}
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type ModelPickerModalProps = {
isOpen: boolean;
models: ModelInfo[];
selectedModelId: string | null;
favorites: Set<string>;
existingModelIds: Set<string>;
canModelFit: (modelId: string) => boolean;
onSelect: (modelId: string) => void;
onClose: () => void;
onToggleFavorite: (baseModelId: string) => void;
onAddModel: (modelId: string) => Promise<void>;
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
isOpen,
models,
selectedModelId,
favorites,
existingModelIds,
canModelFit,
onSelect,
onClose,
onToggleFavorite,
onAddModel,
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
let searchQuery = $state("");
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
let infoGroup = $state<ModelGroup | null>(null);
// Download availability per model group
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Aggregate download availability per group (available if ANY variant is available)
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
let hfTrendingModels = $state<HuggingFaceModel[]>([]);
let hfIsSearching = $state(false);
let hfIsLoadingTrending = $state(false);
let addingModelId = $state<string | null>(null);
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset state when modal opens
$effect(() => {
if (isOpen) {
searchQuery = "";
selectedFamily = null;
expandedGroups = new Set();
showFilters = false;
hfSearchQuery = "";
hfSearchResults = [];
manualModelId = "";
addModelError = null;
}
});
// Fetch trending models when HuggingFace is selected
$effect(() => {
if (
selectedFamily === "huggingface" &&
hfTrendingModels.length === 0 &&
!hfIsLoadingTrending
) {
fetchTrendingModels();
}
});
async function fetchTrendingModels() {
hfIsLoadingTrending = true;
try {
const response = await fetch("/models/search?query=&limit=20");
if (response.ok) {
hfTrendingModels = await response.json();
}
} catch (error) {
console.error("Failed to fetch trending models:", error);
} finally {
hfIsLoadingTrending = false;
}
}
async function searchHuggingFace(query: string) {
if (query.length < 2) {
hfSearchResults = [];
return;
}
hfIsSearching = true;
try {
const response = await fetch(
`/models/search?query=${encodeURIComponent(query)}&limit=20`,
);
if (response.ok) {
hfSearchResults = await response.json();
} else {
hfSearchResults = [];
}
} catch (error) {
console.error("Failed to search models:", error);
hfSearchResults = [];
} finally {
hfIsSearching = false;
}
}
function handleHfSearchInput(query: string) {
hfSearchQuery = query;
addModelError = null;
if (hfSearchDebounceTimer) {
clearTimeout(hfSearchDebounceTimer);
}
if (query.length >= 2) {
hfSearchDebounceTimer = setTimeout(() => {
searchHuggingFace(query);
}, 300);
} else {
hfSearchResults = [];
}
}
async function handleAddModel(modelId: string) {
addingModelId = modelId;
addModelError = null;
try {
await onAddModel(modelId);
} catch (error) {
addModelError =
error instanceof Error ? error.message : "Failed to add model";
} finally {
addingModelId = null;
}
}
async function handleAddManualModel() {
if (!manualModelId.trim()) return;
await handleAddModel(manualModelId.trim());
if (!addModelError) {
manualModelId = "";
}
}
function handleSelectHfModel(modelId: string) {
onSelect(modelId);
onClose();
}
// Models to display in HuggingFace view
const hfDisplayModels = $derived.by((): HuggingFaceModel[] => {
if (hfSearchQuery.length >= 2) {
return hfSearchResults;
}
return hfTrendingModels;
});
// Group models by base_model
const groupedModels = $derived.by((): ModelGroup[] => {
const groups = new Map<string, ModelGroup>();
for (const model of models) {
const groupId = model.base_model || model.id;
const groupName = model.base_model || model.name || model.id;
if (!groups.has(groupId)) {
groups.set(groupId, {
id: groupId,
name: groupName,
capabilities: model.capabilities || ["text"],
family: model.family || "",
variants: [],
smallestVariant: model,
hasMultipleVariants: false,
});
}
const group = groups.get(groupId)!;
group.variants.push(model);
// Track smallest variant
if (
(model.storage_size_megabytes || 0) <
(group.smallestVariant.storage_size_megabytes || Infinity)
) {
group.smallestVariant = model;
}
// Update capabilities if not set
if (
group.capabilities.length <= 1 &&
model.capabilities &&
model.capabilities.length > 1
) {
group.capabilities = model.capabilities;
}
if (!group.family && model.family) {
group.family = model.family;
}
}
// Sort variants within each group by size
for (const group of groups.values()) {
group.variants.sort(
(a, b) =>
(a.storage_size_megabytes || 0) - (b.storage_size_megabytes || 0),
);
group.hasMultipleVariants = group.variants.length > 1;
}
// Convert to array and sort by smallest variant size (biggest first)
return Array.from(groups.values()).sort((a, b) => {
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
});
// Get unique families
const uniqueFamilies = $derived.by((): string[] => {
const families = new Set<string>();
for (const group of groupedModels) {
if (group.family) {
families.add(group.family);
}
}
const familyOrder = [
"kimi",
"qwen",
"glm",
"minimax",
"deepseek",
"gpt-oss",
"llama",
];
return Array.from(families).sort((a, b) => {
const aIdx = familyOrder.indexOf(a);
const bIdx = familyOrder.indexOf(b);
if (aIdx === -1 && bIdx === -1) return a.localeCompare(b);
if (aIdx === -1) return 1;
if (bIdx === -1) return -1;
return aIdx - bIdx;
});
});
// Filter models based on search, family, and filters
const filteredGroups = $derived.by((): ModelGroup[] => {
let result: ModelGroup[] = [...groupedModels];
// Filter by family
if (selectedFamily === "favorites") {
result = result.filter((g) => favorites.has(g.id));
} else if (selectedFamily && selectedFamily !== "huggingface") {
result = result.filter((g) => g.family === selectedFamily);
}
// Filter by search query
if (searchQuery.trim()) {
const query = searchQuery.toLowerCase().trim();
result = result.filter(
(g) =>
g.name.toLowerCase().includes(query) ||
g.variants.some(
(v) =>
v.id.toLowerCase().includes(query) ||
(v.name || "").toLowerCase().includes(query),
),
);
}
// Filter by capabilities
if (filters.capabilities.length > 0) {
result = result.filter((g) =>
filters.capabilities.every((cap) => g.capabilities.includes(cap)),
);
}
// Filter by size range
if (filters.sizeRange) {
const { min, max } = filters.sizeRange;
result = result.filter((g) => {
const sizeGB = (g.smallestVariant.storage_size_megabytes || 0) / 1024;
return sizeGB >= min && sizeGB <= max;
});
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
const bFits = b.variants.some((v) => canModelFit(v.id));
if (aFits && !bFits) return -1;
if (!aFits && bFits) return 1;
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
return result;
});
// Check if any favorites exist
const hasFavorites = $derived(favorites.size > 0);
function toggleGroupExpanded(groupId: string) {
const next = new Set(expandedGroups);
if (next.has(groupId)) {
next.delete(groupId);
} else {
next.add(groupId);
}
expandedGroups = next;
}
function handleSelect(modelId: string) {
onSelect(modelId);
onClose();
}
function handleKeydown(e: KeyboardEvent) {
if (e.key === "Escape") {
onClose();
}
}
function handleFiltersChange(newFilters: FilterState) {
filters = newFilters;
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 || filters.sizeRange !== null,
);
</script>
<svelte:window onkeydown={handleKeydown} />
{#if isOpen}
<!-- Backdrop -->
<div
class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
transition:fade={{ duration: 200 }}
onclick={onClose}
role="presentation"
></div>
<!-- Modal -->
<div
class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,600px)] h-[min(80vh,700px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
role="dialog"
aria-modal="true"
aria-label="Select a model"
>
<!-- Header with search -->
<div
class="flex items-center gap-2 p-3 border-b border-exo-yellow/10 bg-exo-medium-gray/30"
>
{#if selectedFamily === "huggingface"}
<!-- HuggingFace search -->
<svg
class="w-5 h-5 text-orange-400/60 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search mlx-community models..."
value={hfSearchQuery}
oninput={(e) => handleHfSearchInput(e.currentTarget.value)}
/>
{#if hfIsSearching}
<div class="flex-shrink-0">
<span
class="w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin block"
></span>
</div>
{/if}
{:else}
<!-- Normal model search -->
<svg
class="w-5 h-5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search models..."
bind:value={searchQuery}
/>
<!-- Cluster memory -->
<span
class="text-xs font-mono flex-shrink-0"
title="Cluster memory usage"
><span class="text-exo-yellow">{Math.round(usedMemoryGB)}GB</span
><span class="text-white/40">/{Math.round(totalMemoryGB)}GB</span
></span
>
<!-- Filter button -->
<div class="relative filter-toggle">
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors {hasActiveFilters
? 'text-exo-yellow'
: 'text-white/50'}"
onclick={() => (showFilters = !showFilters)}
title="Filter by capability or size"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path d="M10 18h4v-2h-4v2zM3 6v2h18V6H3zm3 7h12v-2H6v2z" />
</svg>
</button>
{#if showFilters}
<ModelFilterPopover
{filters}
onChange={handleFiltersChange}
onClear={clearFilters}
onClose={() => (showFilters = false)}
/>
{/if}
</div>
{/if}
<!-- Close button -->
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors text-white/50 hover:text-white/70"
onclick={onClose}
title="Close model picker"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<!-- Body -->
<div class="flex flex-1 overflow-hidden">
<!-- Family sidebar -->
<FamilySidebar
families={uniqueFamilies}
{selectedFamily}
{hasFavorites}
onSelect={(family) => (selectedFamily = family)}
/>
<!-- Model list -->
<div class="flex-1 overflow-y-auto flex flex-col">
{#if selectedFamily === "huggingface"}
<!-- HuggingFace Hub view -->
<div class="flex-1 flex flex-col min-h-0">
<!-- Section header -->
<div
class="sticky top-0 z-10 px-3 py-2 bg-exo-dark-gray/95 border-b border-exo-yellow/10"
>
<span class="text-xs font-mono text-white/40">
{#if hfSearchQuery.length >= 2}
Search results for "{hfSearchQuery}"
{:else}
Trending on mlx-community
{/if}
</span>
</div>
<!-- Results list -->
<div class="flex-1 overflow-y-auto">
{#if hfIsLoadingTrending && hfTrendingModels.length === 0}
<div
class="flex items-center justify-center py-12 text-white/40"
>
<span
class="w-5 h-5 border-2 border-orange-400 border-t-transparent rounded-full animate-spin mr-2"
></span>
<span class="font-mono text-sm"
>Loading trending models...</span
>
</div>
{:else if hfDisplayModels.length === 0}
<div
class="flex flex-col items-center justify-center py-12 text-white/40"
>
<svg
class="w-10 h-10 mb-2"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 13.5c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm4 0c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm2-4.5H8c0-2.21 1.79-4 4-4s4 1.79 4 4z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hfSearchQuery}
<p class="font-mono text-xs mt-1">
Try a different search term
</p>
{/if}
</div>
{:else}
{#each hfDisplayModels as model}
<HuggingFaceResultItem
{model}
isAdded={existingModelIds.has(model.id)}
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
/>
{/each}
{/if}
</div>
<!-- Manual input footer -->
<div
class="sticky bottom-0 border-t border-exo-yellow/10 bg-exo-dark-gray p-3"
>
{#if addModelError}
<div
class="bg-red-500/10 border border-red-500/30 rounded px-3 py-2 mb-2"
>
<p class="text-red-400 text-xs font-mono break-words">
{addModelError}
</p>
</div>
{/if}
<div class="flex gap-2">
<input
type="text"
class="flex-1 bg-exo-black/60 border border-exo-yellow/30 rounded px-3 py-1.5 text-xs font-mono text-white placeholder-white/30 focus:outline-none focus:border-exo-yellow/50"
placeholder="Or paste model ID directly..."
bind:value={manualModelId}
onkeydown={(e) => {
if (e.key === "Enter") handleAddManualModel();
}}
/>
<button
type="button"
onclick={handleAddManualModel}
disabled={!manualModelId.trim() || addingModelId !== null}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded disabled:opacity-50 disabled:cursor-not-allowed"
>
Add
</button>
</div>
</div>
</div>
{:else if filteredGroups.length === 0}
<div
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
>
<svg class="w-12 h-12 mb-3" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hasActiveFilters || searchQuery}
<button
type="button"
class="mt-2 text-xs text-exo-yellow hover:underline"
onclick={() => {
searchQuery = "";
clearFilters();
}}
>
Clear filters
</button>
{/if}
</div>
{:else}
{#each filteredGroups as group}
<ModelPickerGroup
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatus={getGroupDownloadAvailability(group)}
/>
{/each}
{/if}
</div>
</div>
<!-- Footer with active filters indicator -->
{#if hasActiveFilters}
<div
class="flex items-center gap-2 px-3 py-2 border-t border-exo-yellow/10 bg-exo-medium-gray/20 text-xs font-mono text-white/50"
>
<span>Filters:</span>
{#each filters.capabilities as cap}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded"
>{cap}</span
>
{/each}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
</span>
{/if}
<button
type="button"
class="ml-auto text-white/40 hover:text-white/60"
onclick={clearFilters}
>
Clear all
</button>
</div>
{/if}
</div>
<!-- Info modal -->
{#if infoGroup}
<div
class="fixed inset-0 z-[60] bg-black/60"
transition:fade={{ duration: 150 }}
onclick={() => (infoGroup = null)}
role="presentation"
></div>
<div
class="fixed z-[60] top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(80vw,400px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl p-4"
transition:fly={{ y: 10, duration: 200, easing: cubicOut }}
role="dialog"
aria-modal="true"
>
<div class="flex items-start justify-between mb-3">
<h3 class="font-mono text-lg text-white">{infoGroup.name}</h3>
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors text-white/50"
onclick={() => (infoGroup = null)}
title="Close model details"
aria-label="Close info dialog"
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<div class="space-y-2 text-xs font-mono">
<div class="flex items-center gap-2">
<span class="text-white/40">Family:</span>
<span class="text-white/70">{infoGroup.family || "Unknown"}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Capabilities:</span>
<span class="text-white/70">{infoGroup.capabilities.join(", ")}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Variants:</span>
<span class="text-white/70">{infoGroup.variants.length}</span>
</div>
{#if infoGroup.variants.length > 0}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<span class="text-white/40">Available quantizations:</span>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoGroup.variants as variant}
<span
class="px-1.5 py-0.5 bg-white/10 text-white/60 rounded text-[10px]"
>
{variant.quantization || "default"} ({Math.round(
(variant.storage_size_megabytes || 0) / 1024,
)}GB)
</span>
{/each}
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5 text-green-400"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path d="M22 11.08V12a10 10 0 1 1-5.93-9.14" />
<polyline points="22 4 12 14.01 9 11.01" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}
{/if}

View File

@@ -8,12 +8,16 @@
nodeThunderboltBridge,
type NodeInfo,
} from "$lib/stores/app.svelte";
import { getModelDownloadStatus } from "$lib/utils/downloads";
interface Props {
class?: string;
highlightedNodes?: Set<string>;
filteredNodes?: Set<string>;
onNodeClick?: (nodeId: string) => void;
downloadsData?: Record<string, unknown[]>;
activeModelId?: string | null;
onDownloadToNode?: (nodeId: string) => void;
}
let {
@@ -21,6 +25,9 @@
highlightedNodes = new Set(),
filteredNodes = new Set(),
onNodeClick,
downloadsData,
activeModelId = null,
onDownloadToNode,
}: Props = $props();
let svgContainer: SVGSVGElement | undefined = $state();
@@ -907,6 +914,95 @@
.attr("stroke-width", strokeWidth);
}
// --- Download Status Indicator (top-right of device icon) ---
if (activeModelId && downloadsData) {
const dlStatus = getModelDownloadStatus(
downloadsData,
nodeInfo.id,
activeModelId,
);
if (dlStatus) {
const indicatorSize = isMinimized ? 8 : 12;
const indicatorX =
nodeInfo.x + iconBaseWidth / 2 - indicatorSize * 0.3;
const indicatorY =
nodeInfo.y - iconBaseHeight / 2 - indicatorSize * 0.3;
if (dlStatus === "DownloadCompleted") {
// Green circle with white checkmark
const dlG = nodeG.append("g").attr("class", "download-indicator");
dlG.append("title").text("Downloaded on this node");
dlG
.append("circle")
.attr("cx", indicatorX)
.attr("cy", indicatorY)
.attr("r", indicatorSize)
.attr("fill", "#22c55e")
.attr("stroke", "#15803d")
.attr("stroke-width", 1);
// Checkmark path
const checkScale = indicatorSize / 12;
dlG
.append("path")
.attr(
"d",
`M${indicatorX - 4 * checkScale},${indicatorY} L${indicatorX - 1 * checkScale},${indicatorY + 3.5 * checkScale} L${indicatorX + 5 * checkScale},${indicatorY - 3.5 * checkScale}`,
)
.attr("stroke", "white")
.attr("stroke-width", 2 * checkScale)
.attr("fill", "none")
.attr("stroke-linecap", "round")
.attr("stroke-linejoin", "round");
} else if (onDownloadToNode) {
// Download arrow icon (not completed — pending/ongoing/failed)
const dlG = nodeG
.append("g")
.attr("class", "download-indicator")
.style("cursor", "pointer");
dlG.append("title").text("Download to this node");
dlG
.append("circle")
.attr("cx", indicatorX)
.attr("cy", indicatorY)
.attr("r", indicatorSize)
.attr("fill", "rgba(80, 80, 90, 0.9)")
.attr("stroke", "rgba(255,215,0,0.5)")
.attr("stroke-width", 1);
// Arrow-down path
const arrowScale = indicatorSize / 12;
dlG
.append("path")
.attr(
"d",
`M${indicatorX},${indicatorY - 4 * arrowScale} L${indicatorX},${indicatorY + 1.5 * arrowScale} M${indicatorX - 3 * arrowScale},${indicatorY - 1 * arrowScale} L${indicatorX},${indicatorY + 1.5 * arrowScale} L${indicatorX + 3 * arrowScale},${indicatorY - 1 * arrowScale} M${indicatorX - 4 * arrowScale},${indicatorY + 4 * arrowScale} L${indicatorX + 4 * arrowScale},${indicatorY + 4 * arrowScale}`,
)
.attr("stroke", "rgba(255,215,0,0.8)")
.attr("stroke-width", 1.5 * arrowScale)
.attr("fill", "none")
.attr("stroke-linecap", "round")
.attr("stroke-linejoin", "round");
dlG.on("click", (event: MouseEvent) => {
event.stopPropagation();
onDownloadToNode(nodeInfo.id);
});
dlG
.on("mouseenter", function () {
d3.select(this)
.select("circle")
.attr("stroke", "rgba(255,215,0,1)")
.attr("fill", "rgba(100, 100, 110, 0.9)");
})
.on("mouseleave", function () {
d3.select(this)
.select("circle")
.attr("stroke", "rgba(255,215,0,0.5)")
.attr("fill", "rgba(80, 80, 90, 0.9)");
});
}
}
}
// --- Vertical GPU Bar (right side of icon) ---
// Show in both full mode and minimized mode (scaled appropriately)
if (showFullLabels || isMinimized) {
@@ -1153,6 +1249,8 @@
const _hoveredNodeId = hoveredNodeId;
const _filteredNodes = filteredNodes;
const _highlightedNodes = highlightedNodes;
const _downloadsData = downloadsData;
const _activeModelId = activeModelId;
if (_data) {
renderGraph();
}

View File

@@ -6,3 +6,9 @@ export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as ImageParamsPanel } from "./ImageParamsPanel.svelte";
export { default as FamilyLogos } from "./FamilyLogos.svelte";
export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

View File

@@ -0,0 +1,97 @@
/**
* FavoritesStore - Manages favorite models with localStorage persistence
*/
import { browser } from "$app/environment";
const FAVORITES_KEY = "exo-favorite-models";
class FavoritesStore {
favorites = $state<Set<string>>(new Set());
constructor() {
if (browser) {
this.loadFromStorage();
}
}
private loadFromStorage() {
try {
const stored = localStorage.getItem(FAVORITES_KEY);
if (stored) {
const parsed = JSON.parse(stored) as string[];
this.favorites = new Set(parsed);
}
} catch (error) {
console.error("Failed to load favorites:", error);
}
}
private saveToStorage() {
try {
const array = Array.from(this.favorites);
localStorage.setItem(FAVORITES_KEY, JSON.stringify(array));
} catch (error) {
console.error("Failed to save favorites:", error);
}
}
add(baseModelId: string) {
const next = new Set(this.favorites);
next.add(baseModelId);
this.favorites = next;
this.saveToStorage();
}
remove(baseModelId: string) {
const next = new Set(this.favorites);
next.delete(baseModelId);
this.favorites = next;
this.saveToStorage();
}
toggle(baseModelId: string) {
if (this.favorites.has(baseModelId)) {
this.remove(baseModelId);
} else {
this.add(baseModelId);
}
}
isFavorite(baseModelId: string): boolean {
return this.favorites.has(baseModelId);
}
getAll(): string[] {
return Array.from(this.favorites);
}
getSet(): Set<string> {
return new Set(this.favorites);
}
hasAny(): boolean {
return this.favorites.size > 0;
}
clearAll() {
this.favorites = new Set();
this.saveToStorage();
}
}
export const favoritesStore = new FavoritesStore();
export const favorites = () => favoritesStore.favorites;
export const hasFavorites = () => favoritesStore.hasAny();
export const isFavorite = (baseModelId: string) =>
favoritesStore.isFavorite(baseModelId);
export const toggleFavorite = (baseModelId: string) =>
favoritesStore.toggle(baseModelId);
export const addFavorite = (baseModelId: string) =>
favoritesStore.add(baseModelId);
export const removeFavorite = (baseModelId: string) =>
favoritesStore.remove(baseModelId);
export const getFavorites = () => favoritesStore.getAll();
export const getFavoritesSet = () => favoritesStore.getSet();
export const clearFavorites = () => favoritesStore.clearAll();

View File

@@ -0,0 +1,152 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
/**
* Find shard metadata for a model from any download entry across all nodes.
* Returns the first match found (completed entries are preferred).
*/
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}

View File

@@ -5,7 +5,13 @@
ChatMessages,
ChatSidebar,
ModelCard,
ModelPickerModal,
} from "$lib/components";
import {
favorites,
toggleFavorite,
getFavoritesSet,
} from "$lib/stores/favorites.svelte";
import {
hasStartedChat,
isTopologyMinimized,
@@ -33,9 +39,11 @@
toggleChatSidebarVisible,
thunderboltBridgeCycles,
nodeThunderboltBridge,
startDownload,
type DownloadProgress,
type PlacementPreview,
} from "$lib/stores/app.svelte";
import { getShardMetadataForModel } from "$lib/utils/downloads";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import { fade, fly } from "svelte/transition";
import { cubicInOut } from "svelte/easing";
@@ -100,6 +108,11 @@
storage_size_megabytes?: number;
tasks?: string[];
hugging_face_id?: string;
is_custom?: boolean;
family?: string;
quantization?: string;
base_model?: string;
capabilities?: string[];
}>
>([]);
@@ -211,9 +224,11 @@
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state("");
// Model picker modal state
let isModelPickerOpen = $state(false);
// Favorites state (reactive)
const favoritesSet = $derived(getFavoritesSet());
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -530,6 +545,58 @@
}
}
async function addModelFromPicker(modelId: string) {
const response = await fetch("/models/add", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ model_id: modelId }),
});
if (!response.ok) {
let message = `Failed to add model (${response.status}: ${response.statusText})`;
try {
const err = await response.json();
if (err.detail) message = err.detail;
} catch {
// use default message
}
throw new Error(message);
}
await fetchModels();
}
async function deleteCustomModel(modelId: string) {
try {
const response = await fetch(
`/models/custom/${encodeURIComponent(modelId)}`,
{ method: "DELETE" },
);
if (response.ok) {
await fetchModels();
}
} catch {
console.error("Failed to delete custom model");
}
}
function handleModelPickerSelect(modelId: string) {
selectPreviewModel(modelId);
saveLaunchDefaults();
isModelPickerOpen = false;
}
async function handleTopologyDownload(nodeId: string) {
if (!selectedModelId) return;
const shardMeta = getShardMetadataForModel(
downloadsData ?? {},
selectedModelId,
);
if (shardMeta) {
await startDownload(nodeId, shardMeta);
}
}
async function launchInstance(
modelId: string,
specificPreview?: PlacementPreview | null,
@@ -1668,6 +1735,9 @@
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
{downloadsData}
activeModelId={selectedModelId}
onDownloadToNode={handleTopologyDownload}
/>
<!-- Thunderbolt Bridge Cycle Warning -->
@@ -2360,14 +2430,12 @@
>
</div>
<!-- Model Dropdown (Custom) -->
<div class="flex-shrink-0 mb-3 relative">
<!-- Model Picker Button -->
<div class="flex-shrink-0 mb-3">
<button
type="button"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
? 'border-exo-yellow/70'
: ''}"
onclick={() => (isModelPickerOpen = true)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 relative"
>
{#if selectedModelId}
{@const foundModel = models.find(
@@ -2375,54 +2443,12 @@
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
<span
class="flex items-center gap-2 text-exo-light-gray truncate"
>
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2439,142 +2465,24 @@
{:else}
<span class="text-white/50"> SELECT MODEL </span>
{/if}
</button>
<div
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
{#if isModelDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-40 cursor-default"
onclick={() => (isModelDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel -->
<div
class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-50 max-h-64 overflow-y-auto"
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none"
>
<!-- Search within dropdown -->
<div
class="sticky top-0 bg-exo-dark-gray border-b border-exo-medium-gray/30 p-2"
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<input
type="text"
placeholder="Search models..."
bind:value={modelDropdownSearch}
class="w-full bg-exo-dark-gray/60 border border-exo-medium-gray/30 rounded px-2 py-1.5 text-xs font-mono text-white/80 placeholder:text-white/40 focus:outline-none focus:border-exo-yellow/50"
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</div>
<!-- Options -->
<div class="py-1">
{#each sortedModels().filter((m) => !modelDropdownSearch || (m.name || m.id)
.toLowerCase()
.includes(modelDropdownSearch.toLowerCase())) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
}
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
: 'text-white/30 cursor-default'}"
>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image generation model"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span
class="flex-shrink-0 text-xs {modelCanFit
? 'text-white/50'
: 'text-red-400/60'}"
>
{sizeGB >= 1
? sizeGB.toFixed(0)
: sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">
No models found
</div>
{/each}
</div>
</svg>
</div>
{/if}
</button>
</div>
<!-- Configuration Options -->
@@ -2879,6 +2787,9 @@
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
{downloadsData}
activeModelId={selectedModelId}
onDownloadToNode={handleTopologyDownload}
/>
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
@@ -3354,3 +3265,24 @@
{/if}
</main>
</div>
<ModelPickerModal
isOpen={isModelPickerOpen}
{models}
{selectedModelId}
favorites={favoritesSet}
existingModelIds={new Set(models.map((m) => m.id))}
canModelFit={(modelId) => {
const model = models.find((m) => m.id === modelId);
return model ? hasEnoughMemory(model) : false;
}}
onSelect={handleModelPickerSelect}
onClose={() => (isModelPickerOpen = false)}
onToggleFavorite={toggleFavorite}
onAddModel={addModelFromPicker}
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>

View File

@@ -10,6 +10,11 @@
deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import {
extractModelIdFromDownload,
extractShardMetadata,
getDownloadTag,
} from "$lib/utils/downloads";
type FileProgress = {
name: string;
@@ -98,26 +103,7 @@
return Math.min(100, Math.max(0, value as number));
}
function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const shardObj = shardMetadata as Record<string, unknown>;
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
// extractModelIdFromDownload imported from $lib/utils/downloads
function parseDownloadProgress(
payload: Record<string, unknown>,
@@ -197,14 +183,10 @@
for (const downloadWrapped of nodeEntries) {
if (!downloadWrapped || typeof downloadWrapped !== "object") continue;
const keys = Object.keys(downloadWrapped as Record<string, unknown>);
if (keys.length !== 1) continue;
const tagged = getDownloadTag(downloadWrapped);
if (!tagged) continue;
const downloadKind = keys[0];
const downloadPayload = (downloadWrapped as Record<string, unknown>)[
downloadKind
] as Record<string, unknown>;
if (!downloadPayload) continue;
const [downloadKind, downloadPayload] = tagged;
const modelId =
extractModelIdFromDownload(downloadPayload) ?? "unknown-model";
@@ -273,10 +255,7 @@
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ??
downloadPayload.shardMetadata) as
| Record<string, unknown>
| undefined;
const shardMetadata = extractShardMetadata(downloadPayload);
const entry: ModelEntry = {
modelId,

View File

@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405874409472

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 765577920512

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 122406567936

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 229780750336

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 198556925568

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 286737579648

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 396963397248

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 19327352832

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 22548578304

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 26843545600

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = "4bit"
base_model = "Kimi K2"
capabilities = ["text"]
[storage_size]
in_bytes = 620622774272

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 706522120192

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 662498705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 1B"
capabilities = ["text"]
[storage_size]
in_bytes = 729808896

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 1863319552

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 3501195264

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 76799803392

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 4637851648

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 8954839040

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 16882073600

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 100086644736

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 342884352

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 698351616

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 141733920768

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 268435456000

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 17612931072

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 33279705088

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 289910292480

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 579820584960

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 46976204800

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 47080074240

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 120B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 70652212224

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 20B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 12025908224

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "fp16"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 144383672320

View File

@@ -1,4 +1,5 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
@@ -60,10 +61,37 @@ class DownloadCoordinator:
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
async def _check_internet_connection(self) -> None:
first_connection = True
while True:
await asyncio.sleep(10)
# Assume that internet connection is set to False on 443 errors.
if self.shard_downloader.internet_connection:
continue
self._test_internet_connection()
if first_connection and self.shard_downloader.internet_connection:
first_connection = False
self._tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
@@ -241,7 +269,7 @@ class DownloadCoordinator:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
logger.debug(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
@@ -274,10 +302,10 @@ class DownloadCoordinator:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
logger.debug(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
await anyio.sleep(60)
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"

View File

@@ -49,6 +49,10 @@ class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
class HuggingFaceRateLimitError(Exception):
"""429 Huggingface code"""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
@@ -154,49 +158,76 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
_fetched_file_lists_this_session: set[str] = set()
async def fetch_file_list_with_cache(
model_id: ModelId, revision: str = "main", recursive: bool = False
model_id: ModelId,
revision: str = "main",
recursive: bool = False,
skip_internet: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_key = f"{model_id.normalize()}--{revision}"
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
cache_file
):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
if skip_internet:
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
model_id,
revision,
recursive=recursive,
on_connection_lost=on_connection_lost,
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
# No cache available, propagate the error
raise
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
async def fetch_file_list_with_retry(
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId,
revision: str = "main",
path: str = "",
recursive: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
) -> list[FileListEntry]:
n_attempts = 30
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
await asyncio.sleep(2.0**attempt)
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -216,7 +247,11 @@ async def _fetch_file_list(
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
if response.status == 200:
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
)
elif response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
@@ -249,7 +284,7 @@ def create_http_session(
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_read_timeout = 60
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
@@ -324,8 +359,9 @@ async def download_file_with_retry(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> Path:
n_attempts = 30
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
@@ -333,14 +369,19 @@ async def download_file_with_retry(
)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
break
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -542,7 +583,9 @@ async def download_shard(
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
max_parallel_downloads: int = 8,
skip_download: bool = False,
skip_internet: bool = False,
allow_patterns: list[str] | None = None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.debug(f"Downloading {shard.model_card.model_id=}")
@@ -562,7 +605,11 @@ async def download_shard(
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id, revision, recursive=True
shard.model_card.model_id,
revision,
recursive=True,
skip_internet=skip_internet,
on_connection_lost=on_connection_lost,
)
filtered_file_list = list(
filter_repo_objects(
@@ -672,6 +719,7 @@ async def download_shard(
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
)
if not skip_download:

View File

@@ -1,4 +1,5 @@
import asyncio
from asyncio import create_task
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -7,7 +8,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -49,6 +50,10 @@ class SingletonShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -85,6 +90,10 @@ class CachedShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -142,6 +151,8 @@ class ResumableShardDownloader(ShardDownloader):
self.on_progress_wrapper,
max_parallel_downloads=self.max_parallel_downloads,
allow_patterns=allow_patterns,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return target_dir
@@ -154,13 +165,24 @@ class ResumableShardDownloader(ShardDownloader):
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)
return await download_shard(
shard, self.on_progress_wrapper, skip_download=True
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
# Kick off download status coroutines concurrently
semaphore = asyncio.Semaphore(self.max_parallel_downloads)
async def download_with_semaphore(
model_card: ModelCard,
) -> tuple[Path, RepoDownloadProgress]:
async with semaphore:
return await _status_for_model(model_card.model_id)
tasks = [
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
create_task(download_with_semaphore(model_card))
for model_card in await get_model_cards()
]
for task in asyncio.as_completed(tasks):

View File

@@ -16,6 +16,11 @@ from exo.shared.types.worker.shards import (
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
internet_connection: bool = False
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
@abstractmethod
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False

View File

@@ -66,7 +66,9 @@ def chat_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages if input_messages else "",
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,

View File

@@ -141,7 +141,9 @@ def claude_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages if input_messages else "",
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,

View File

@@ -43,10 +43,10 @@ def _extract_content(content: str | list[ResponseContentPart]) -> str:
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: str | list[InputMessage]
input_value: list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = request.input
input_value = [InputMessage(role="user", content=request.input)]
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
@@ -95,7 +95,11 @@ def responses_request_to_text_generation(
}
)
input_value = input_messages if input_messages else ""
input_value = (
input_messages
if input_messages
else [InputMessage(role="user", content="")]
)
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(

View File

@@ -40,6 +40,7 @@ from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
@@ -47,12 +48,15 @@ from exo.shared.constants import (
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
delete_custom_card,
get_model_cards,
is_custom_card,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
AddCustomModelParams,
AdvancedImageParams,
BenchChatCompletionRequest,
BenchChatCompletionResponse,
@@ -70,6 +74,7 @@ from exo.shared.types.api import (
ErrorResponse,
FinishReason,
GenerationStats,
HuggingFaceSearchResult,
ImageData,
ImageEditsTaskParams,
ImageGenerationResponse,
@@ -138,7 +143,6 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@@ -146,18 +150,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -204,7 +196,7 @@ class API:
self.app.mount(
"/",
StaticFiles(
directory=find_dashboard(),
directory=DASHBOARD_DIR,
html=True,
),
name="dashboard",
@@ -269,6 +261,9 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
self.app.delete("/models/custom/{model_id:path}")(self.delete_custom_model)
self.app.get("/models/search")(self.search_models)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
@@ -381,10 +376,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
model_card = await ModelCard.load(model_id)
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
for sharding in (Sharding.Pipeline, Sharding.Tensor):
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
@@ -399,96 +391,93 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
instance=None,
error=str(exc),
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
return PlacementPreviewResponse(previews=previews)
@@ -652,23 +641,21 @@ class API:
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved
instance.shard_assignments.model_id == model_id
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved)
await self._trigger_notify_user_to_download_model(model_id)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {resolved}",
detail=f"No instance found for model {model_id}",
)
return resolved
return model_id
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
@@ -1236,11 +1223,70 @@ class API:
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
family=card.family,
quantization=card.quantization,
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in MODEL_CARDS.values()
for card in await get_model_cards()
]
)
async def add_custom_model(self, payload: AddCustomModelParams) -> ModelListModel:
"""Fetch a model from HuggingFace and save as a custom model card."""
try:
card = await ModelCard.fetch_from_hf(payload.model_id)
except Exception as exc:
raise HTTPException(
status_code=400, detail=f"Failed to fetch model: {exc}"
) from exc
return ModelListModel(
id=card.model_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=True,
)
async def delete_custom_model(self, model_id: ModelId) -> JSONResponse:
"""Delete a user-added custom model card."""
deleted = await delete_custom_card(model_id)
if not deleted:
raise HTTPException(status_code=404, detail="Custom model card not found")
return JSONResponse(
{"message": "Model card deleted", "model_id": str(model_id)}
)
async def search_models(
self, query: str = "", limit: int = 20
) -> list[HuggingFaceSearchResult]:
"""Search HuggingFace Hub for mlx-community models."""
from huggingface_hub import list_models
results = list_models(
search=query or None,
author="mlx-community",
sort="downloads",
limit=limit,
)
return [
HuggingFaceSearchResult(
id=m.id,
author=m.author or "",
downloads=m.downloads or 0,
likes=m.likes or 0,
last_modified=str(m.last_modified or ""),
tags=list(m.tags or []),
)
for m in results
]
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -28,7 +28,7 @@ from exo.shared.types.profiling import (
)
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
InstanceMeta,
MlxRingInstance,
@@ -136,7 +136,9 @@ async def test_master():
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
input=[
InputMessage(role="user", content="Hello, how are you?")
],
),
)
),
@@ -189,7 +191,7 @@ async def test_master():
assert isinstance(events[2].event.task, TextGenerationTask)
assert events[2].event.task.task_params == TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
input=[InputMessage(role="user", content="Hello, how are you?")],
)
await master.shutdown()

View File

@@ -2,6 +2,8 @@ import os
import sys
from pathlib import Path
from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
@@ -31,6 +33,14 @@ EXO_MODELS_DIR = (
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"
@@ -48,6 +58,8 @@ LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
EXO_CUSTOM_MODEL_CARDS_DIR = EXO_DATA_HOME / "custom_model_cards"
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_TRACING_CACHE_DIR = EXO_CACHE_HOME / "traces"

View File

@@ -12,16 +12,47 @@ from pydantic import (
BaseModel,
Field,
PositiveInt,
ValidationError,
field_validator,
model_validator,
)
from tomlkit.exceptions import TOMLKitError
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.constants import (
EXO_CUSTOM_MODEL_CARDS_DIR,
EXO_ENABLE_IMAGE_MODELS,
RESOURCES_DIR,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
# kinda ugly...
# TODO: load search path from config.toml
_custom_cards_dir = Path(str(EXO_CUSTOM_MODEL_CARDS_DIR))
_csp = [Path(RESOURCES_DIR) / "inference_model_cards", _custom_cards_dir]
if EXO_ENABLE_IMAGE_MODELS:
_csp.append(Path(RESOURCES_DIR) / "image_model_cards")
CARD_SEARCH_PATH = _csp
_card_cache: dict[ModelId, "ModelCard"] = {}
async def _refresh_card_cache():
for path in CARD_SEARCH_PATH:
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
async def get_model_cards() -> list["ModelCard"]:
if len(_card_cache) == 0:
await _refresh_card_cache()
return list(_card_cache.values())
class ModelTask(str, Enum):
@@ -47,6 +78,10 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
family: str = ""
quantization: str = ""
base_model: str = ""
capabilities: list[str] = []
@field_validator("tasks", mode="before")
@classmethod
@@ -55,31 +90,37 @@ class ModelCard(CamelCaseModel):
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
py = self.model_dump(exclude_none=True)
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
async def save_to_custom_dir(self) -> None:
await aios.makedirs(str(_custom_cards_dir), exist_ok=True)
await self.save(_custom_cards_dir / (self.model_id.normalize() + ".toml"))
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
# Is it okay that model card.load defaults to network access if the card doesn't exist? do we want to be more explicit here?
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
for card in MODEL_CARDS.values():
if card.model_id == model_id:
return card
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if model_id not in _card_cache:
await _refresh_card_cache()
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
return await ModelCard.fetch_from_hf(model_id)
@staticmethod
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
# TODO: failure if files do not exist
config_data = await fetch_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mem_size_bytes = await fetch_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
@@ -89,544 +130,33 @@ class ModelCard(CamelCaseModel):
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
)
await mc.save_to_custom_dir()
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"deepseek-v3.1-8bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2-thinking": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2.5": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2.5"),
storage_size=Memory.from_gb(617),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-8bit": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-bf16": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-70b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.2
"llama-3.2-1b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.3
"llama-3.3-70b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-fp16": ModelCard(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# qwen3
"qwen3-0.6b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-0.6b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(44900),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"glm-4.5-air-bf16": ModelCard(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-8bit-gs32": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"minimax-m2.1-3bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
}
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("exolabs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("exolabs/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}
async def delete_custom_card(model_id: ModelId) -> bool:
"""Delete a user-added custom model card. Returns True if deleted."""
card_path = _custom_cards_dir / (ModelId(model_id).normalize() + ".toml")
if await card_path.exists():
await card_path.unlink()
_card_cache.pop(model_id, None)
return True
return False
def _generate_image_model_quant_variants(
def is_custom_card(model_id: ModelId) -> bool:
"""Check if a model card exists in the custom cards directory."""
import os
card_path = Path(str(EXO_CUSTOM_MODEL_CARDS_DIR)) / (
ModelId(model_id).normalize() + ".toml"
)
return os.path.isfile(str(card_path))
# TODO: quantizing and dynamically creating model cards
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
@@ -706,15 +236,6 @@ def _generate_image_model_quant_variants(
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
@@ -767,7 +288,7 @@ class ConfigData(BaseModel):
return data
async def get_config_data(model_id: ModelId) -> ConfigData:
async def fetch_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.download.download_utils import (
download_file_with_retry,
@@ -789,7 +310,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
async def fetch_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.download.download_utils import (
download_file_with_retry,

View File

@@ -42,6 +42,11 @@ class ModelListModel(BaseModel):
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
is_custom: bool = Field(default=False)
family: str = Field(default="")
quantization: str = Field(default="")
base_model: str = Field(default="")
capabilities: list[str] = Field(default_factory=list)
class ModelList(BaseModel):
@@ -201,6 +206,19 @@ class BenchChatCompletionRequest(ChatCompletionRequest):
pass
class AddCustomModelParams(BaseModel):
model_id: ModelId
class HuggingFaceSearchResult(BaseModel):
id: str
author: str = ""
downloads: int = 0
likes: int = 0
last_modified: str = ""
tags: list[str] = Field(default_factory=list)
class PlaceInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline

View File

@@ -28,7 +28,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
"""
model: ModelId
input: str | list[InputMessage]
input: list[InputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None

View File

@@ -1,31 +1,45 @@
import os
import sys
from pathlib import Path
from typing import cast
def find_resources() -> Path:
resources = _find_resources_in_repo() or _find_resources_in_bundle()
if resources is None:
raise FileNotFoundError(
"Unable to locate resources. Did you clone the repo properly?"
)
return resources
def _find_resources_in_repo() -> Path | None:
current_module = Path(__file__).resolve()
for parent in current_module.parents:
build = parent / "resources"
if build.is_dir():
return build
return None
def _find_resources_in_bundle() -> Path | None:
frozen_root = cast(str | None, getattr(sys, "_MEIPASS", None))
if frozen_root is None:
return None
candidate = Path(frozen_root) / "resources"
if candidate.is_dir():
return candidate
return None
def find_dashboard() -> Path:
dashboard = (
_find_dashboard_in_env()
or _find_dashboard_in_repo()
or _find_dashboard_in_bundle()
)
dashboard = _find_dashboard_in_repo() or _find_dashboard_in_bundle()
if not dashboard:
raise FileNotFoundError(
"Unable to locate dashboard assets - make sure the dashboard has been built, or export DASHBOARD_DIR if you've built the dashboard elsewhere."
"Unable to locate dashboard assets - you probably forgot to run `cd dashboard && npm install && npm run build && cd ..`"
)
return dashboard
def _find_dashboard_in_env() -> Path | None:
env = os.environ.get("DASHBOARD_DIR")
if not env:
return None
resolved_env = Path(env).expanduser().resolve()
return resolved_env
def _find_dashboard_in_repo() -> Path | None:
current_module = Path(__file__).resolve()
for parent in current_module.parents:

View File

@@ -164,6 +164,12 @@ def _inner_model(model: nn.Module) -> nn.Module:
if isinstance(inner, nn.Module):
return inner
inner = getattr(model, "language_model", None)
if isinstance(inner, nn.Module):
inner_inner = getattr(inner, "model", None)
if isinstance(inner_inner, nn.Module):
return inner_inner
raise ValueError("Model must either have a 'model' or 'transformer' attribute")

View File

@@ -17,7 +17,7 @@ from exo.shared.types.api import (
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
@@ -100,7 +100,7 @@ def warmup_inference(
tokenizer=tokenizer,
task_params=TextGenerationTaskParams(
model=ModelId(""),
input=content,
input=[InputMessage(role="user", content=content)],
),
)

View File

@@ -384,6 +384,17 @@ def load_tokenizer_for_model_id(
eos_token_ids=eos_token_ids,
)
if "gemma-3" in model_id_lower:
gemma_3_eos_id = 1
gemma_3_end_of_turn_id = 106
if tokenizer.eos_token_ids is not None:
if gemma_3_end_of_turn_id not in tokenizer.eos_token_ids:
tokenizer.eos_token_ids = list(tokenizer.eos_token_ids) + [
gemma_3_end_of_turn_id
]
else:
tokenizer.eos_token_ids = [gemma_3_eos_id, gemma_3_end_of_turn_id]
return tokenizer
@@ -436,16 +447,11 @@ def apply_chat_template(
)
# Convert input to messages
if isinstance(task_params.input, str):
# Simple string input becomes a single user message
formatted_messages.append({"role": "user", "content": task_params.input})
else:
# List of InputMessage
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
prompt: str = tokenizer.apply_chat_template(
formatted_messages,

View File

@@ -918,15 +918,10 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
Extracts the first user input text and checks for debug triggers.
"""
prompt: str
if isinstance(task_params.input, str):
prompt = task_params.input
else:
# List of InputMessage - get first message content
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
if not prompt:
return

View File

@@ -14,7 +14,7 @@ from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
@@ -114,7 +114,7 @@ def run_gpt_oss_pipeline_device(
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=prompt_text,
input=[InputMessage(role="user", content=prompt_text)],
max_output_tokens=max_tokens,
)
@@ -182,7 +182,7 @@ def run_gpt_oss_tensor_parallel_device(
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=prompt_text,
input=[InputMessage(role="user", content=prompt_text)],
max_output_tokens=max_tokens,
)

View File

@@ -16,7 +16,7 @@ from exo.download.download_utils import (
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
@@ -76,7 +76,7 @@ def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
for card in asyncio.run(get_model_cards()):
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
@@ -296,7 +296,7 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
card for card in await get_model_cards() if "kimi" in card.model_id.lower()
]
if not kimi_models:
@@ -343,7 +343,7 @@ async def test_kimi_tokenizer_specifically():
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
card for card in await get_model_cards() if "glm" in card.model_id.lower()
]
if not glm_model_cards:

View File

@@ -2,7 +2,7 @@ from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.tasks import Task, TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
@@ -59,7 +59,9 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
)
result = plan_mod.plan(
@@ -106,7 +108,9 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
)
result = plan_mod.plan(
@@ -150,7 +154,9 @@ def test_plan_does_not_forward_tasks_for_other_instances():
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
)
result = plan_mod.plan(
@@ -198,7 +204,9 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
)
other_task_id = TaskId("other-task")

View File

@@ -22,7 +22,7 @@ from exo.shared.types.tasks import (
TaskStatus,
TextGeneration,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -86,7 +86,7 @@ SHUTDOWN_TASK = Shutdown(
CHAT_PARAMS = TextGenerationTaskParams(
model=MODEL_A_ID,
input="hello",
input=[InputMessage(role="user", content="hello")],
stream=True,
max_output_tokens=4,
temperature=0.0,

View File

@@ -10,7 +10,7 @@ from loguru import logger
from pydantic import BaseModel
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
@@ -23,7 +23,7 @@ from exo.shared.types.tasks import (
Task,
TextGeneration,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -114,13 +114,13 @@ async def run_test(test: Tests):
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
i = await ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if test.kind in ["jaccl", "both"]:
i = await jaccl_instance(test)
if i is None:
yield "no model found"
return
@@ -145,7 +145,7 @@ async def run_test(test: Tests):
return StreamingResponse(run())
def ring_instance(test: Tests, hn: str) -> Instance | None:
async def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
@@ -158,11 +158,7 @@ def ring_instance(test: Tests, hn: str) -> Instance | None:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
card = await ModelCard.load(test.model_id)
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52417,
@@ -200,7 +196,11 @@ async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
input=[
InputMessage(
role="user", content="What is the capital of France?"
)
],
),
command_id=CommandId("yo"),
instance_id=iid,
@@ -230,12 +230,8 @@ async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
return []
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = await ModelCard.load(test.model_id)
world_size = len(test.devs)
assert test.ibv_devs