Compare commits

17 Commits

Author SHA1 Message Date
Jake Hillion
0806c92ac9 deps: update mlx-lm
Summary of changes:

DSV3 MLA
Fix sliding window mask during generation
Fix batch mamba
Fix Step 3.5 Flash model conversion
Deepseek V3.2 implementation fixes
fix: handle GLM 4.7 tool call fallbacks
server: support chat_template_kwargs and top_logprobs
Add Step 3.5 Flash
allow creation of BatchRotatingKVCache instead of BatchKVCache when empty cache(s) are passed to BatchGenerator
enable loading custom models
fix cli
Support distributed inference in the server
fix mixed quant
Add LongCat Flash Lite
actually add cli
Fix Kimi K2.5 tool call handling
Fix for Exception - MultiLinear.to_quantized() missing 'mode'
Fix NemotronH config compatibility with HuggingFace format
Bump mlx version and version

Full changelog: 96699e6d...f18526f8
2026-02-05 16:34:41 +00:00
Evan Quiney
572e647908 better cancellation (#1388)
a lot of our cleanup logic wasn't running leading to bad shutdown states

## changes
- added `try: except` blocks around most task groups
- made the runner shutdown code synchronous
- abandon the MpReceiver's recv_async thread on cancellation
- this only occurs during runner shutdown, the queue closing from the
other end should terminate the mp.Queue, cleaning up the thread in its
own time. i could try other methods if this is not sufficient.
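The cancellation pattern in the changes above can be sketched roughly as follows. This is an illustration using the standard-library `asyncio`, not the project's actual task-group code; `run_with_cleanup`, `work`, and `cleanup` are hypothetical names.

```python
import asyncio
from collections.abc import Awaitable, Callable


async def run_with_cleanup(
    work: Callable[[], Awaitable[None]],
    cleanup: Callable[[], None],
) -> None:
    """Run `work`; if it is cancelled, run synchronous cleanup, then re-raise."""
    try:
        await work()
    except asyncio.CancelledError:
        # Cleanup is synchronous on purpose: there is no await here, so a
        # second cancellation cannot interrupt it halfway through.
        cleanup()
        raise
```

Re-raising after cleanup lets the caller still observe the cancellation, so the outer task group can shut down normally.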

## outcome
ctrl-c just works now! minus the tokio panic of course :) no more
hypercorn lifespan errors though!
2026-02-05 15:22:33 +00:00
Evan Quiney
e59ebd986d set exo as the nix default package (#1391)
!!!
2026-02-05 15:15:52 +00:00
Alex Cheema
5c2f29f3f2 feat: show download availability in model picker (#1377)
## Motivation

Users browsing models in the picker need to know which models are
already downloaded and ready to run on their cluster, without having to
check the downloads page separately.

## Changes

- **ModelPickerModal.svelte**: Computes per-model download availability
by checking which nodes have `DownloadCompleted` entries and summing
their total RAM against the model's storage size. Passes availability
data to `ModelPickerGroup`. Enhances the info modal with a "Downloaded
on:" section showing node friendly names with green badges.
- **ModelPickerGroup.svelte**: Accepts new `downloadStatus` prop. Shows
a green checkmark-in-circle icon next to models that are downloaded on
sufficient nodes. Tooltip shows which nodes have the model.
- **+page.svelte**: Passes `downloadsData` and `topologyNodes` to
`ModelPickerModal`.

## Why It Works

The download state from `/state` already tracks per-node completed
downloads. The shared `getNodesWithModelDownloaded()` utility (from PR
#1375) finds nodes with `DownloadCompleted` entries for each model.
Total RAM is summed from the topology node data (using `ram_total`, not
`ram_available`) and compared to the model's `storage_size_megabytes` to
determine if there's enough aggregate memory. This is intentionally a
simple heuristic — not a full placement preview.
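The heuristic described above can be sketched in Python. This is not the dashboard's Svelte code: `NodeInfo`, `download_status`, and the field `completed_downloads` are hypothetical names, and the sketch assumes RAM is already expressed in megabytes so that it can be compared directly with `storage_size_megabytes`.

```python
from dataclasses import dataclass, field


@dataclass
class NodeInfo:
    """Hypothetical per-node record for illustration only."""
    node_id: str
    ram_total_megabytes: int  # total RAM, not currently available RAM
    completed_downloads: set[str] = field(default_factory=set)


def download_status(
    nodes: list[NodeInfo], model_id: str, storage_size_megabytes: int
) -> tuple[list[str], bool]:
    """Return (IDs of nodes holding the model, whether their total RAM suffices)."""
    holders = [n for n in nodes if model_id in n.completed_downloads]
    total_ram = sum(n.ram_total_megabytes for n in holders)
    enough = bool(holders) and total_ram >= storage_size_megabytes
    return [n.node_id for n in holders], enough
```

Only nodes that already hold the model contribute RAM to the sum, which is why this is a readiness hint rather than a placement preview.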

## Test Plan

### Manual Testing
- Open the model picker modal
- Verify downloaded models show a green checkmark icon
- Verify the checkmark appears dimmer for models downloaded on nodes
with insufficient total RAM
- Click the (i) info button on a downloaded model
- Verify "Downloaded on:" section appears with correct node names
- Verify models with no downloads show no indicator

### Automated Testing
- Dashboard builds successfully (`npm run build`)
- No new Python changes requiring type checking

> **Note:** This is a chained PR. Base branch is
`alexcheema/topology-download-indicators` (#1375).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 14:32:53 +00:00
Alex Cheema
ffe6396c91 Add Qwen3-Coder-Next model cards (#1367)
## Motivation

Qwen3-Coder-Next just dropped on mlx-community in several quantizations.
It's an 80B MoE model (Qwen3NextForCausalLM) which we already have
tensor parallelism support for via QwenShardingStrategy — just needs
model cards.

## Changes

Added model cards for all 5 available quantizations:
- `mlx-community/Qwen3-Coder-Next-4bit` (~46GB)
- `mlx-community/Qwen3-Coder-Next-5bit` (~58GB)
- `mlx-community/Qwen3-Coder-Next-6bit` (~69GB)
- `mlx-community/Qwen3-Coder-Next-8bit` (~89GB)
- `mlx-community/Qwen3-Coder-Next-bf16` (~158GB)

All with `supports_tensor = true` since the architecture is already
supported.

## Why It Works

`Qwen3NextForCausalLM` is already handled by QwenShardingStrategy in
auto_parallel.py and is in the supports_tensor allowlist in
model_cards.py. No code changes needed — just the TOML card files.

## Test Plan

### Manual Testing
<!-- n/a - model card addition only -->

### Automated Testing
- `basedpyright` — 0 errors
- `ruff check` — passes
- `nix fmt` — no changes
- `pytest` — 173 passed, 1 skipped


🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 13:37:18 +00:00
Jake Hillion
3a9baeb9db EXO: add CLI flags for root install/uninstall
The macOS app required user interaction via AppleScript prompts to
install or uninstall network configuration components, making automated
deployments difficult.

Added --install and --uninstall command line flags that execute the
network setup scripts directly when running as root, bypassing GUI
prompts. Created a new main.swift entry point that parses CLI arguments
and delegates to NetworkSetupHelper's new direct execution methods.

This enables headless installation via `sudo EXO --install` for
automated deployment scenarios while preserving the existing GUI
behavior when launched normally.

Test plan:

- Deployed to a machine that didn't have the content installed. Got
  blocked on the popup and EXO never launched.
- Relaunched EXO, confirmed it still never starts because of the popup.
- Ran `sudo /Applications/EXO.app/Contents/MacOS/EXO --install`
- Launched EXO - the API started as expected.
- Ran `sudo /Applications/EXO.app/Contents/MacOS/EXO --uninstall`
- Launched EXO - got the popup.
2026-02-05 13:27:46 +00:00
Alex Cheema
01b86a9e81 feat: add uncertainty visualization with token-level logprobs (#1180)
## Motivation

Adds uncertainty visualization to the chat interface, allowing users to
see token-level confidence scores and regenerate responses from any
point in the generation. This enables users to:
- Understand model confidence at each token
- Explore alternative completions by regenerating from uncertain tokens
- Debug and analyze model behavior

## Changes

### Uncertainty Visualization
- Add `TokenHeatmap` component showing token-level probability coloring
- Toggle uncertainty view per message with bar chart icon
- Display tooltip with probability, logprob, and top alternative tokens
on hover

### Regenerate from Token
- Add "Regenerate from here" button in token tooltip
- Use `continue_final_message` in chat template to continue within same
turn (no EOS tokens)
- Add `continue_from_prefix` flag to `ChatCompletionTaskParams`

### Request Cancellation
- Add `AbortController` to cancel in-flight requests when regenerating
mid-generation
- Gracefully handle `BrokenResourceError` server-side when the client
disconnects

### Additional APIs
- Add Claude Messages API support (`/v1/messages`)
- Add OpenAI Responses API support (`/v1/responses`)

## Why It Works

- **Proper continuation**: Using `continue_final_message=True` instead
of `add_generation_prompt=True` keeps the assistant turn open, allowing
the model to continue naturally from the prefix without end-of-turn
markers
- **Clean cancellation**: AbortController aborts the HTTP request, and
server catches `BrokenResourceError` to avoid crashes
- **Stable hover during generation**: TokenHeatmap tracks hover by index
(stable across re-renders) with longer hide delay during generation
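For reference, the relationship between a logprob and the probability shown in the tooltip is a plain exponential. The sketch below illustrates that conversion plus a coloring bucket; the thresholds and the names `token_probability` and `confidence_bucket` are assumptions for illustration, not values taken from `TokenHeatmap`.

```python
import math


def token_probability(logprob: float) -> float:
    """Convert a natural-log probability back to a probability in [0, 1]."""
    return math.exp(logprob)


def confidence_bucket(logprob: float, high: float = 0.9, low: float = 0.5) -> str:
    """Map a token's logprob to a coarse bucket (thresholds are hypothetical)."""
    p = token_probability(logprob)
    if p >= high:
        return "high"
    if p >= low:
        return "medium"
    return "low"
```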

## Test Plan

### Manual Testing
<!-- Hardware: MacBook Pro M1 -->
- Send a message and verify logprobs are collected
- Enable uncertainty view and verify token coloring based on probability
- Hover over tokens to see tooltip with alternatives
- Click "Regenerate from here" on a token mid-response
- Verify the response continues naturally from that point
- Verify aborting mid-generation and regenerating works without server
crash

### Automated Testing
- Added tests for Claude Messages API adapter
- Added tests for OpenAI Responses API adapter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
2026-02-05 05:21:26 -08:00
rltakashige
221640a65b Acknowledge task after runner status is updated (#1381)
## Motivation

Duplicate tasks are still observed.

## Changes

Moved task acknowledgement to after the runner has changed its status.

## Why It Works

Tasks now remain pending until the runner has updated its status.

## Test Plan

### Manual Testing
Seems to work fine from manual testing. Hard to test a race condition
though.

### Automated Testing
Updated the event ordering test.
2026-02-05 12:00:37 +00:00
ciaranbor
6177550c34 Ciaran/parallel cfg (#1361)
## Motivation

Enable parallel classifier-free guidance (CFG) for Qwen image models.
CFG requires two forward passes (positive/negative prompts) - this
allows them to run on separate nodes simultaneously, reducing latency.

## Changes

- Added uses_cfg flag to ModelCard to identify CFG-based models
- Extended PipelineShardMetadata with CFG topology fields (cfg_rank,
cfg_world_size, peer device info)
- Updated placement to create two CFG groups with reversed ordering
(places CFG peers as ring neighbors)
- Refactored DiffusionRunner to process CFG branches separately with
exchange at last pipeline stage
- Added get_cfg_branch_data() to PromptData for single-branch embeddings
- Fixed seed handling in API for distributed consistency
- Fixed image yield to only emit from CFG rank 0 at last stage
- Increased num_sync_steps_factor from 0.125 to 0.25 for Qwen

## Why It Works

- 2 nodes + CFG: Both run all layers, process different CFG branches in
parallel
- 4+ even nodes + CFG: Hybrid - 2 CFG groups × N/2 pipeline stages
- Odd nodes or non-CFG: Falls back to pure pipeline parallelism

Ring topology places CFG peers as neighbors to enable direct exchange.
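One way to read the three cases above is the sketch below. It is an interpretation of "reversed ordering", not the actual placement code: `plan_cfg_placement` and the dictionary keys other than `cfg_rank` and `cfg_world_size` are hypothetical. The second group is laid out in reverse so that the last pipeline stage of each group ends up adjacent in the ring.

```python
def plan_cfg_placement(node_ids: list[str], uses_cfg: bool) -> list[dict]:
    """Assign (cfg_rank, pipeline_rank) per node, in ring order."""
    n = len(node_ids)
    if not uses_cfg or n % 2 != 0:
        # Fallback: a single group, pure pipeline parallelism.
        return [
            {"node": node, "cfg_rank": 0, "cfg_world_size": 1,
             "pipeline_rank": i, "pipeline_world_size": n}
            for i, node in enumerate(node_ids)
        ]
    half = n // 2
    plan = []
    for i, node in enumerate(node_ids):
        if i < half:
            cfg_rank, stage = 0, i          # first group: stages ascend
        else:
            cfg_rank, stage = 1, n - 1 - i  # second group: stages descend
        plan.append(
            {"node": node, "cfg_rank": cfg_rank, "cfg_world_size": 2,
             "pipeline_rank": stage, "pipeline_world_size": half}
        )
    return plan
```

With four nodes this yields stages `0, 1` for the first group and `1, 0` for the second, so the two last-stage nodes sit next to each other. With two nodes each group has a single stage, meaning both nodes run all layers.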

## Test Plan

### Manual Testing

Verified performance gain for Qwen-Image on 2-node and 4-node clusters.
Non-CFG models still work.

### Automated Testing
 
Added tests in test_placement_utils.py covering 2-node CFG parallel,
4-node hybrid, odd-node fallback, and non-CFG pipeline modes.
2026-02-04 21:16:35 +00:00
Evan Quiney
7b6cad94c6 add resources dir to nix (#1376)
add resources directory to the nix exo package, and fix the env for the dashboard dir
2026-02-04 16:38:43 +00:00
Alex Cheema
41ed7afb3b feat: add model picker modal with grouped models and HF Hub search (#1369)
## Motivation

Reimplements the model picker modal from #1191 on top of the custom
model support branch. Replaces the inline model dropdown with a
full-featured modal that groups models by base model, supports
filtering, favorites, and HuggingFace Hub search.

## Changes

**Backend:**
- Add `family`, `quantization`, `base_model`, `capabilities` metadata
fields to `ModelCard` and all 40 TOML model cards
- Pass new fields through `ModelListModel` and `get_models()` API
response
- Add `GET /models/search` endpoint using
`huggingface_hub.list_models()`

**Dashboard (7 new files):**
- `ModelPickerModal.svelte` — Main modal with search, family filtering,
HuggingFace Hub tab
- `ModelPickerGroup.svelte` — Expandable model group row with
quantization variants
- `FamilySidebar.svelte` — Vertical sidebar with family icons (All,
Favorites, Hub, model families)
- `FamilyLogos.svelte` — SVG icons for each model family
- `ModelFilterPopover.svelte` — Capability and size range filters
- `HuggingFaceResultItem.svelte` — HF search result item with
download/like counts
- `favorites.svelte.ts` — localStorage-backed favorites store

**Integration:**
- Replace inline dropdown in `+page.svelte` with button that opens
`ModelPickerModal`
- Custom models shown in Hub tab with delete support

**Polish:**
- Real brand logos (Meta, Qwen, DeepSeek, OpenAI, GLM, MiniMax, Kimi,
HuggingFace) from Simple Icons / LobeHub
- Clean SVG stroke icons for capabilities (thinking, code, vision, image
gen)
- Consistent `border-exo-yellow/10` borders, descriptive tooltips
throughout
- Cluster memory (used/total) shown in modal header
- Selected model highlight with checkmark for both single and
multi-variant groups
- Cursor pointer on all interactive elements, fix filter popover
click-outside bug
- Custom models now appear in All tab alongside built-in models

## Bug Fix: Gemma 3 EOS tokens

Also included in this branch: fix for Gemma 3 models generating infinite
`<end_of_turn>` tokens. The tokenizer's `eos_token_ids` was missing
token ID 106 (`<end_of_turn>`), so generation never stopped. The fix
appends this token to the EOS list after loading the tokenizer. Also
handles `eos_token_ids` being a `set` (not just a `list`).

## Why It Works

Model metadata (family, capabilities, etc.) is stored directly in TOML
cards rather than derived from heuristics, ensuring accuracy. The modal
groups models by `base_model` field so quantization variants appear
together. Custom models are separated into the Hub tab since they lack
grouping metadata.
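The grouping can be sketched as below. Only the `base_model` field comes from the description; `group_models` and the dictionary representation of a card are assumptions for illustration.

```python
from collections import defaultdict


def group_models(cards: list[dict]) -> tuple[dict[str, list[dict]], list[dict]]:
    """Split cards into groups keyed by base_model, plus cards lacking that field."""
    groups: dict[str, list[dict]] = defaultdict(list)
    ungrouped: list[dict] = []
    for card in cards:
        base = card.get("base_model")
        if base:
            groups[base].append(card)
        else:
            # No grouping metadata, e.g. a custom model.
            ungrouped.append(card)
    return dict(groups), ungrouped
```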

## Test Plan

### Manual Testing
- Open dashboard, click model selector to open modal
- Browse models by family sidebar, search, and filters
- Expand model groups to see quantization variants
- Star favorites and verify persistence across page reloads
- Navigate to Hub tab, search and add models
- Verify error messages shown for invalid model IDs
- Run a Gemma 3 model and verify generation stops at `<end_of_turn>`

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — clean
- `uv run pytest src/` — 173 passed
- `cd dashboard && npm run build` — builds successfully

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 05:56:23 -08:00
Alex Cheema
2063278906 feat: add custom HuggingFace model support (#1368)
## Motivation

Users should be able to run any HuggingFace model, not just the ones we
ship TOML cards for. Continues the aim of #1191 with a minimal
implementation on top of the current TOML model card system.

Custom cards are saved to `~/.exo/custom_model_cards/` rather than the
bundled `resources/inference_model_cards/` because `RESOURCES_DIR` is
read-only in PyInstaller bundles (`sys._MEIPASS`). This also fixes
`fetch_from_hf` which was saving cards to the wrong path (`resources/`
root instead of `resources/inference_model_cards/`).

## Changes

- Add `EXO_CUSTOM_MODEL_CARDS_DIR` constant
(`~/.exo/custom_model_cards/`)
- Update `model_cards.py`: add custom dir to search path, fix
`save_to_custom_dir`, add `delete_custom_card`/`is_custom_card`
- Add `POST /models/add` and `DELETE /models/custom/{model_id}` API
endpoints
- Add `is_custom` field to `ModelListModel` API response
- Dashboard: add custom model input form in dropdown, delete button for
custom models, show actual API errors, auto-select newly added model

## Why It Works

Two separate directories for model cards: the bundled read-only
`resources/inference_model_cards/` for built-in cards, and user-writable
`~/.exo/custom_model_cards/` for custom cards. Both are scanned when
listing models. This works in all environments including PyInstaller
bundles where `RESOURCES_DIR` points to `sys._MEIPASS`.
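The two-directory scan can be sketched as follows. `list_model_cards` is a hypothetical name and the real `model_cards.py` logic may differ; the point is that a missing custom directory is tolerated and each card records which directory it came from.

```python
from pathlib import Path


def list_model_cards(bundled_dir: Path, custom_dir: Path) -> list[tuple[Path, bool]]:
    """Return (card_path, is_custom) for every TOML card in both directories."""
    cards: list[tuple[Path, bool]] = []
    for directory, is_custom in ((bundled_dir, False), (custom_dir, True)):
        if not directory.is_dir():
            # The custom directory may not exist until a card is first added.
            continue
        for path in sorted(directory.glob("*.toml")):
            cards.append((path, is_custom))
    return cards
```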

## Test Plan

### Manual Testing
- Add a custom model via the dropdown (e.g.
`mlx-community/Llama-3.2-1B-Instruct-4bit`)
- Verify it appears in the model list with the delete (x) button
- Delete it and verify it disappears
- Try adding an invalid model ID and verify the actual error is shown

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `uv run pytest src/` — passes
- `cd dashboard && npm run build` — builds

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 05:06:15 -08:00
rltakashige
a0f4f36355 Reduce reliance on internet (#1363)
## Motivation

Offline users currently have to wait for every retry to fail before
being able to launch a model.
For users that restart clusters often or share API keys between devices,
we also spam HuggingFace with downloads every 5 minutes.
These issues are caused by `_emit_existing_download_progress` being
inefficient.

## Changes

- Only query HuggingFace once while EXO is running (assumption being
that a change should only be reflected on a new EXO session)
- Only query HuggingFace when there is an internet connection (polling
connectivity every 10 seconds)
- Request download progress if we switch from no connectivity ->
connected to reduce the wait.
- Reduce download progress sleep as it's no longer expensive (queries
cache most of the time).
- Reduce retries as 30 is way too many.
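The first three changes above amount to a small piece of state. The sketch below is one possible shape for it, not the actual implementation: `HubQueryGate` and its method are hypothetical, and the connectivity sample is assumed to come from the periodic poll.

```python
class HubQueryGate:
    """Hypothetical helper: one remote query per session, only while online."""

    def __init__(self) -> None:
        self._queried = False
        self._was_connected = False

    def update(self, is_connected: bool) -> tuple[bool, bool]:
        """Return (query_remote, refresh_now) for one connectivity sample."""
        # Going from offline to online triggers an immediate progress refresh.
        came_online = is_connected and not self._was_connected
        self._was_connected = is_connected
        # The remote is queried at most once, and never while offline.
        query_remote = is_connected and not self._queried
        if query_remote:
            self._queried = True
        return query_remote, came_online
```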

## Test Plan

### Manual Testing
Manually tested the behaviour.

### Automated Testing
None, should I add any? We do have some tests for this folder, but they
are probably not too helpful.
2026-02-03 20:03:29 +00:00
Alex Cheema
acb97127bf Normalize TextGenerationTaskParams.input to list[InputMessage] (#1360)
## Motivation

With the addition of the Responses API, we introduced `str |
list[InputMessage]` as the type for `TextGenerationTaskParams.input`
since the Responses API supports sending input as a plain string. But
there was no reason to leak that flexibility past the API adapter
boundary — it just meant every downstream consumer had to do `if
isinstance(messages, str):` checks, adding complexity for no benefit.

## Changes

- Changed `TextGenerationTaskParams.input` from `str |
list[InputMessage]` to `list[InputMessage]`
- Each API adapter (Chat Completions, Claude Messages, Responses) now
normalizes to `list[InputMessage]` at the boundary
- Removed `isinstance(task_params.input, str)` branches in
`utils_mlx.py` and `runner.py`
- Wrapped string inputs in `[InputMessage(role="user", content=...)]` in
the warmup path and all test files

## Why It Works

The API adapters are the only place where we deal with raw user input
formats. By normalizing there, all downstream code (worker, runner, MLX
engine) can just assume `list[InputMessage]` and skip the type-checking
branches. The type checker (`basedpyright`) catches any missed call sites
statically, before the code runs.

## Test Plan

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — applied
- `uv run pytest` — 174 passed, 1 skipped

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 06:01:56 -08:00
Evan Quiney
d90605f198 migrate model cards to .toml files (#1354)
2026-02-03 12:32:06 +00:00
Evan Quiney
f400b4d7c5 fix InstanceViewModel.swift (#1359)
wasn't caught when we merged the API changes
2026-02-02 18:43:27 +00:00
Evan Quiney
d97bca88e6 improve distributed testing (#1300)
Our distributed test now does a full query cycle for every model loaded
onto the relevant machine. This will help find bugs early, as it already
has found one with Qwen3 Next! I didn't write down what the error was
though. Gooooooood luck with that!

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-02 18:25:39 +00:00
131 changed files with 6001 additions and 1573 deletions

@@ -142,4 +142,6 @@ jobs:
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
export EXO_DASHBOARD_DIR="$PWD/dashboard/"
export EXO_RESOURCES_DIR="$PWD/resources"
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

1
.gitignore vendored
@@ -31,3 +31,4 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp

@@ -108,6 +108,7 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | set[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
@@ -117,7 +118,7 @@ class TokenizerWrapper:
self,
tokenizer: Any,
detokenizer_class: Any = ...,
-eos_token_ids: list[int] | None = ...,
+eos_token_ids: list[int] | set[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,

@@ -14,7 +14,6 @@ import SwiftUI
import UserNotifications
import os.log
@main
struct EXOApp: App {
@StateObject private var controller: ExoProcessController
@StateObject private var stateService: ClusterStateService

@@ -288,6 +288,61 @@ enum NetworkSetupHelper {
"""
}
/// Direct install without GUI (requires root).
/// Returns true on success, false on failure.
static func installDirectly() -> Bool {
let script = makeInstallerScript()
return runShellDirectly(script)
}
/// Direct uninstall without GUI (requires root).
/// Returns true on success, false on failure.
static func uninstallDirectly() -> Bool {
let script = makeUninstallScript()
return runShellDirectly(script)
}
/// Run a shell script directly via Process (no AppleScript, requires root).
/// Returns true on success, false on failure.
private static func runShellDirectly(_ script: String) -> Bool {
let process = Process()
process.executableURL = URL(fileURLWithPath: "/bin/bash")
process.arguments = ["-c", script]
let outputPipe = Pipe()
let errorPipe = Pipe()
process.standardOutput = outputPipe
process.standardError = errorPipe
do {
try process.run()
process.waitUntilExit()
let outputData = outputPipe.fileHandleForReading.readDataToEndOfFile()
let errorData = errorPipe.fileHandleForReading.readDataToEndOfFile()
if let output = String(data: outputData, encoding: .utf8), !output.isEmpty {
print(output)
}
if let errorOutput = String(data: errorData, encoding: .utf8), !errorOutput.isEmpty {
fputs(errorOutput, stderr)
}
if process.terminationStatus == 0 {
logger.info("Shell script completed successfully")
return true
} else {
logger.error("Shell script failed with exit code \(process.terminationStatus)")
return false
}
} catch {
logger.error(
"Failed to run shell script: \(error.localizedDescription, privacy: .public)")
fputs("Error: \(error.localizedDescription)\n", stderr)
return false
}
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script

@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
-let parameters: ChatCompletionTaskParameters?
+let parameters: TextGenerationTaskParameters?
var title: String {
switch kind {

85
app/EXO/EXO/main.swift Normal file
@@ -0,0 +1,85 @@
//
// main.swift
// EXO
//
// Created by Jake Hillion on 2026-02-03.
//
import Foundation
/// Command line options for the EXO app
enum CLICommand {
case install
case uninstall
case help
case none
}
/// Parse command line arguments to determine the CLI command
func parseArguments() -> CLICommand {
let args = CommandLine.arguments
if args.contains("--help") || args.contains("-h") {
return .help
}
if args.contains("--install") {
return .install
}
if args.contains("--uninstall") {
return .uninstall
}
return .none
}
/// Print usage information
func printUsage() {
let programName = (CommandLine.arguments.first as NSString?)?.lastPathComponent ?? "EXO"
print(
"""
Usage: \(programName) [OPTIONS]
Options:
--install Install EXO network configuration (requires root)
--uninstall Uninstall EXO network configuration (requires root)
--help, -h Show this help message
When run without options, starts the normal GUI application.
Examples:
sudo \(programName) --install Install network components as root
sudo \(programName) --uninstall Remove network components as root
""")
}
/// Check if running as root
func isRunningAsRoot() -> Bool {
return getuid() == 0
}
// Main entry point
let command = parseArguments()
switch command {
case .help:
printUsage()
exit(0)
case .install:
if !isRunningAsRoot() {
fputs("Error: --install requires root privileges. Run with sudo.\n", stderr)
exit(1)
}
let success = NetworkSetupHelper.installDirectly()
exit(success ? 0 : 1)
case .uninstall:
if !isRunningAsRoot() {
fputs("Error: --uninstall requires root privileges. Run with sudo.\n", stderr)
exit(1)
}
let success = NetworkSetupHelper.uninstallDirectly()
exit(success ? 0 : 1)
case .none:
// Start normal GUI application
EXOApp.main()
}

@@ -6,11 +6,13 @@
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
interface Props {
class?: string;
@@ -99,6 +101,23 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty heatmap toggle
let heatmapMessageIds = $state<Set<string>>(new Set());
function toggleHeatmap(messageId: string) {
const next = new Set(heatmapMessageIds);
if (next.has(messageId)) {
next.delete(messageId);
} else {
next.add(messageId);
}
heatmapMessageIds = next;
}
function isHeatmapVisible(messageId: string): boolean {
return heatmapMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString("en-US", {
hour12: false,
@@ -548,13 +567,23 @@
>
</div>
{:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading &&
isLastAssistantMessage(message.id)}
onRegenerateFrom={(tokenIndex) =>
regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
{/if}
{/if}
</div>
@@ -629,6 +658,35 @@
</button>
{/if}
<!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
{#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleHeatmap(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
message.id,
)
? 'text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
title={isHeatmapVisible(message.id)
? "Hide uncertainty heatmap"
: "Show uncertainty heatmap"}
>
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
/>
</svg>
</button>
{/if}
<!-- Regenerate button (last assistant message only) -->
{#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
<button

@@ -0,0 +1,73 @@
<script lang="ts">
type FamilyLogoProps = {
family: string;
class?: string;
};
let { family, class: className = "" }: FamilyLogoProps = $props();
</script>
{#if family === "favorites"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else if family === "llama" || family === "meta"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M6.915 4.03c-1.968 0-3.683 1.28-4.871 3.113C.704 9.208 0 11.883 0 14.449c0 .706.07 1.369.21 1.973a6.624 6.624 0 0 0 .265.86 5.297 5.297 0 0 0 .371.761c.696 1.159 1.818 1.927 3.593 1.927 1.497 0 2.633-.671 3.965-2.444.76-1.012 1.144-1.626 2.663-4.32l.756-1.339.186-.325c.061.1.121.196.183.3l2.152 3.595c.724 1.21 1.665 2.556 2.47 3.314 1.046.987 1.992 1.22 3.06 1.22 1.075 0 1.876-.355 2.455-.843a3.743 3.743 0 0 0 .81-.973c.542-.939.861-2.127.861-3.745 0-2.72-.681-5.357-2.084-7.45-1.282-1.912-2.957-2.93-4.716-2.93-1.047 0-2.088.467-3.053 1.308-.652.57-1.257 1.29-1.82 2.05-.69-.875-1.335-1.547-1.958-2.056-1.182-.966-2.315-1.303-3.454-1.303zm10.16 2.053c1.147 0 2.188.758 2.992 1.999 1.132 1.748 1.647 4.195 1.647 6.4 0 1.548-.368 2.9-1.839 2.9-.58 0-1.027-.23-1.664-1.004-.496-.601-1.343-1.878-2.832-4.358l-.617-1.028a44.908 44.908 0 0 0-1.255-1.98c.07-.109.141-.224.211-.327 1.12-1.667 2.118-2.602 3.358-2.602zm-10.201.553c1.265 0 2.058.791 2.675 1.446.307.327.737.871 1.234 1.579l-1.02 1.566c-.757 1.163-1.882 3.017-2.837 4.338-1.191 1.649-1.81 1.817-2.486 1.817-.524 0-1.038-.237-1.383-.794-.263-.426-.464-1.13-.464-2.046 0-2.221.63-4.535 1.66-6.088.454-.687.964-1.226 1.533-1.533a2.264 2.264 0 0 1 1.088-.285z"
/>
</svg>
{:else if family === "qwen"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.604 1.34c.393.69.784 1.382 1.174 2.075a.18.18 0 00.157.091h5.552c.174 0 .322.11.446.327l1.454 2.57c.19.337.24.478.024.837-.26.43-.513.864-.76 1.3l-.367.658c-.106.196-.223.28-.04.512l2.652 4.637c.172.301.111.494-.043.77-.437.785-.882 1.564-1.335 2.34-.159.272-.352.375-.68.37-.777-.016-1.552-.01-2.327.016a.099.099 0 00-.081.05 575.097 575.097 0 01-2.705 4.74c-.169.293-.38.363-.725.364-.997.003-2.002.004-3.017.002a.537.537 0 01-.465-.271l-1.335-2.323a.09.09 0 00-.083-.049H4.982c-.285.03-.553-.001-.805-.092l-1.603-2.77a.543.543 0 01-.002-.54l1.207-2.12a.198.198 0 000-.197 550.951 550.951 0 01-1.875-3.272l-.79-1.395c-.16-.31-.173-.496.095-.965.465-.813.927-1.625 1.387-2.436.132-.234.304-.334.584-.335a338.3 338.3 0 012.589-.001.124.124 0 00.107-.063l2.806-4.895a.488.488 0 01.422-.246c.524-.001 1.053 0 1.583-.006L11.704 1c.341-.003.724.032.9.34zm-3.432.403a.06.06 0 00-.052.03L6.254 6.788a.157.157 0 01-.135.078H3.253c-.056 0-.07.025-.041.074l5.81 10.156c.025.042.013.062-.034.063l-2.795.015a.218.218 0 00-.2.116l-1.32 2.31c-.044.078-.021.118.068.118l5.716.008c.046 0 .08.02.104.061l1.403 2.454c.046.081.092.082.139 0l5.006-8.76.783-1.382a.055.055 0 01.096 0l1.424 2.53a.122.122 0 00.107.062l2.763-.02a.04.04 0 00.035-.02.041.041 0 000-.04l-2.9-5.086a.108.108 0 010-.113l.293-.507 1.12-1.977c.024-.041.012-.062-.035-.062H9.2c-.059 0-.073-.026-.043-.077l1.434-2.505a.107.107 0 000-.114L9.225 1.774a.06.06 0 00-.053-.031zm6.29 8.02c.046 0 .058.02.034.06l-.832 1.465-2.613 4.585a.056.056 0 01-.05.029.058.058 0 01-.05-.029L8.498 9.841c-.02-.034-.01-.052.028-.054l.216-.012 6.722-.012z"
/>
</svg>
{:else if family === "deepseek"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M23.748 4.482c-.254-.124-.364.113-.512.234-.051.039-.094.09-.137.136-.372.397-.806.657-1.373.626-.829-.046-1.537.214-2.163.848-.133-.782-.575-1.248-1.247-1.548-.352-.156-.708-.311-.955-.65-.172-.241-.219-.51-.305-.774-.055-.16-.11-.323-.293-.35-.2-.031-.278.136-.356.276-.313.572-.434 1.202-.422 1.84.027 1.436.633 2.58 1.838 3.393.137.093.172.187.129.323-.082.28-.18.552-.266.833-.055.179-.137.217-.329.14a5.526 5.526 0 01-1.736-1.18c-.857-.828-1.631-1.742-2.597-2.458a11.365 11.365 0 00-.689-.471c-.985-.957.13-1.743.388-1.836.27-.098.093-.432-.779-.428-.872.004-1.67.295-2.687.684a3.055 3.055 0 01-.465.137 9.597 9.597 0 00-2.883-.102c-1.885.21-3.39 1.102-4.497 2.623C.082 8.606-.231 10.684.152 12.85c.403 2.284 1.569 4.175 3.36 5.653 1.858 1.533 3.997 2.284 6.438 2.14 1.482-.085 3.133-.284 4.994-1.86.47.234.962.327 1.78.397.63.059 1.236-.03 1.705-.128.735-.156.684-.837.419-.961-2.155-1.004-1.682-.595-2.113-.926 1.096-1.296 2.746-2.642 3.392-7.003.05-.347.007-.565 0-.845-.004-.17.035-.237.23-.256a4.173 4.173 0 001.545-.475c1.396-.763 1.96-2.015 2.093-3.517.02-.23-.004-.467-.247-.588zM11.581 18c-2.089-1.642-3.102-2.183-3.52-2.16-.392.024-.321.471-.235.763.09.288.207.486.371.739.114.167.192.416-.113.603-.673.416-1.842-.14-1.897-.167-1.361-.802-2.5-1.86-3.301-3.307-.774-1.393-1.224-2.887-1.298-4.482-.02-.386.093-.522.477-.592a4.696 4.696 0 011.529-.039c2.132.312 3.946 1.265 5.468 2.774.868.86 1.525 1.887 2.202 2.891.72 1.066 1.494 2.082 2.48 2.914.348.292.625.514.891.677-.802.09-2.14.11-3.054-.614zm1-6.44a.306.306 0 01.415-.287.302.302 0 01.2.288.306.306 0 01-.31.307.303.303 0 01-.304-.308zm3.11 1.596c-.2.081-.399.151-.59.16a1.245 1.245 0 01-.798-.254c-.274-.23-.47-.358-.552-.758a1.73 1.73 0 01.016-.588c.07-.327-.008-.537-.239-.727-.187-.156-.426-.199-.688-.199a.559.559 0 01-.254-.078c-.11-.054-.2-.19-.114-.358.028-.054.16-.186.192-.21.356-.202.767-.136 1.146.016.352.144.618.408 1.001.782.391.451.462.576.685.914.176.265.336.537.445.848.067.195-.019.354-.25.452z"
/>
</svg>
{:else if family === "openai" || family === "gpt-oss"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M22.2819 9.8211a5.9847 5.9847 0 0 0-.5157-4.9108 6.0462 6.0462 0 0 0-6.5098-2.9A6.0651 6.0651 0 0 0 4.9807 4.1818a5.9847 5.9847 0 0 0-3.9977 2.9 6.0462 6.0462 0 0 0 .7427 7.0966 5.98 5.98 0 0 0 .511 4.9107 6.051 6.051 0 0 0 6.5146 2.9001A5.9847 5.9847 0 0 0 13.2599 24a6.0557 6.0557 0 0 0 5.7718-4.2058 5.9894 5.9894 0 0 0 3.9977-2.9001 6.0557 6.0557 0 0 0-.7475-7.0729zm-9.022 12.6081a4.4755 4.4755 0 0 1-2.8764-1.0408l.1419-.0804 4.7783-2.7582a.7948.7948 0 0 0 .3927-.6813v-6.7369l2.02 1.1686a.071.071 0 0 1 .038.052v5.5826a4.504 4.504 0 0 1-4.4945 4.4944zm-9.6607-4.1254a4.4708 4.4708 0 0 1-.5346-3.0137l.142.0852 4.783 2.7582a.7712.7712 0 0 0 .7806 0l5.8428-3.3685v2.3324a.0804.0804 0 0 1-.0332.0615L9.74 19.9502a4.4992 4.4992 0 0 1-6.1408-1.6464zM2.3408 7.8956a4.485 4.485 0 0 1 2.3655-1.9728V11.6a.7664.7664 0 0 0 .3879.6765l5.8144 3.3543-2.0201 1.1685a.0757.0757 0 0 1-.071 0l-4.8303-2.7865A4.504 4.504 0 0 1 2.3408 7.872zm16.5963 3.8558L13.1038 8.364 15.1192 7.2a.0757.0757 0 0 1 .071 0l4.8303 2.7913a4.4944 4.4944 0 0 1-.6765 8.1042v-5.6772a.79.79 0 0 0-.407-.667zm2.0107-3.0231l-.142-.0852-4.7735-2.7818a.7759.7759 0 0 0-.7854 0L9.409 9.2297V6.8974a.0662.0662 0 0 1 .0284-.0615l4.8303-2.7866a4.4992 4.4992 0 0 1 6.6802 4.66zM8.3065 12.863l-2.02-1.1638a.0804.0804 0 0 1-.038-.0567V6.0742a4.4992 4.4992 0 0 1 7.3757-3.4537l-.142.0805L8.704 5.459a.7948.7948 0 0 0-.3927.6813zm1.0976-2.3654l2.602-1.4998 2.6069 1.4998v2.9994l-2.5974 1.4997-2.6067-1.4997Z"
/>
</svg>
{:else if family === "glm"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M11.991 23.503a.24.24 0 00-.244.248.24.24 0 00.244.249.24.24 0 00.245-.249.24.24 0 00-.22-.247l-.025-.001zM9.671 5.365a1.697 1.697 0 011.099 2.132l-.071.172-.016.04-.018.054c-.07.16-.104.32-.104.498-.035.71.47 1.279 1.186 1.314h.366c1.309.053 2.338 1.173 2.286 2.523-.052 1.332-1.152 2.38-2.478 2.327h-.174c-.715.018-1.274.64-1.239 1.368 0 .124.018.23.053.337.209.373.54.658.96.8.75.23 1.517-.125 1.9-.782l.018-.035c.402-.64 1.17-.96 1.92-.711.854.284 1.378 1.226 1.099 2.167a1.661 1.661 0 01-2.077 1.102 1.711 1.711 0 01-.907-.711l-.017-.035c-.2-.323-.463-.58-.851-.711l-.056-.018a1.646 1.646 0 00-1.954.746 1.66 1.66 0 01-1.065.764 1.677 1.677 0 01-1.989-1.279c-.209-.906.332-1.83 1.257-2.043a1.51 1.51 0 01.296-.035h.018c.68-.071 1.151-.622 1.116-1.333a1.307 1.307 0 00-.227-.693 2.515 2.515 0 01-.366-1.403 2.39 2.39 0 01.366-1.208c.14-.195.21-.444.227-.693.018-.71-.506-1.261-1.186-1.332l-.07-.018a1.43 1.43 0 01-.299-.07l-.05-.019a1.7 1.7 0 01-1.047-2.114 1.68 1.68 0 012.094-1.101zm-5.575 10.11c.26-.264.639-.367.994-.27.355.096.633.379.728.74.095.362-.007.748-.267 1.013-.402.41-1.053.41-1.455 0a1.062 1.062 0 010-1.482zm14.845-.294c.359-.09.738.024.992.297.254.274.344.665.237 1.025-.107.36-.396.634-.756.718-.551.128-1.1-.22-1.23-.781a1.05 1.05 0 01.757-1.26zm-.064-4.39c.314.32.49.753.49 1.206 0 .452-.176.886-.49 1.206-.315.32-.74.5-1.185.5-.444 0-.87-.18-1.184-.5a1.727 1.727 0 010-2.412 1.654 1.654 0 012.369 0zm-11.243.163c.364.484.447 1.128.218 1.691a1.665 1.665 0 01-2.188.923c-.855-.36-1.26-1.358-.907-2.228a1.68 1.68 0 011.33-1.038c.593-.08 1.183.169 1.547.652zm11.545-4.221c.368 0 .708.2.892.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.892.524c-.568 0-1.03-.47-1.03-1.048 0-.579.462-1.048 1.03-1.048zm-14.358 0c.368 0 .707.2.891.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.891.524c-.569 0-1.03-.47-1.03-1.048 0-.579.461-1.048 1.03-1.048zm10.031-1.475c.925 0 1.675.764 1.675 1.706s-.75 1.705-1.675 1.705-1.674-.763-1.674-1.705c0-.942.75-1.706 
1.674-1.706zm-2.626-.684c.362-.082.653-.356.761-.718a1.062 1.062 0 00-.238-1.028 1.017 1.017 0 00-.996-.294c-.547.14-.881.7-.752 1.257.13.558.675.907 1.225.783zm0 16.876c.359-.087.644-.36.75-.72a1.062 1.062 0 00-.237-1.019 1.018 1.018 0 00-.985-.301 1.037 1.037 0 00-.762.717c-.108.361-.017.754.239 1.028.245.263.606.377.953.305l.043-.01zM17.19 3.5a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64a.631.631 0 00-.628.64c0 .355.28.64.628.64zm-10.38 0a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64a.631.631 0 00-.628.64c0 .355.279.64.628.64zm-5.182 7.852a.631.631 0 00-.628.64c0 .354.28.639.628.639a.63.63 0 00.627-.606l.001-.034a.62.62 0 00-.628-.64zm5.182 9.13a.631.631 0 00-.628.64c0 .355.279.64.628.64a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm10.38.018a.631.631 0 00-.628.64c0 .355.28.64.628.64a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64zm5.182-9.148a.631.631 0 00-.628.64c0 .354.279.639.628.639a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm-.384-4.992a.24.24 0 00.244-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249c0 .142.122.249.244.249zM11.991.497a.24.24 0 00.245-.248A.24.24 0 0011.99 0a.24.24 0 00-.244.249c0 .133.108.236.223.247l.021.001zM2.011 6.36a.24.24 0 00.245-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249.24.24 0 00.244.249zm0 11.263a.24.24 0 00-.243.248.24.24 0 00.244.249.24.24 0 00.244-.249.252.252 0 00-.244-.248zm19.995-.018a.24.24 0 00-.245.248.24.24 0 00.245.25.24.24 0 00.244-.25.252.252 0 00-.244-.248z"
/>
</svg>
{:else if family === "minimax"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M16.278 2c1.156 0 2.093.927 2.093 2.07v12.501a.74.74 0 00.744.709.74.74 0 00.743-.709V9.099a2.06 2.06 0 012.071-2.049A2.06 2.06 0 0124 9.1v6.561a.649.649 0 01-.652.645.649.649 0 01-.653-.645V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v7.472a2.037 2.037 0 01-2.048 2.026 2.037 2.037 0 01-2.048-2.026v-12.5a.785.785 0 00-.788-.753.785.785 0 00-.789.752l-.001 15.904A2.037 2.037 0 0113.441 22a2.037 2.037 0 01-2.048-2.026V18.04c0-.356.292-.645.652-.645.36 0 .652.289.652.645v1.934c0 .263.142.506.372.638.23.131.514.131.744 0a.734.734 0 00.372-.638V4.07c0-1.143.937-2.07 2.093-2.07zm-5.674 0c1.156 0 2.093.927 2.093 2.07v11.523a.648.648 0 01-.652.645.648.648 0 01-.652-.645V4.07a.785.785 0 00-.789-.78.785.785 0 00-.789.78v14.013a2.06 2.06 0 01-2.07 2.048 2.06 2.06 0 01-2.071-2.048V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v3.8a2.06 2.06 0 01-2.071 2.049A2.06 2.06 0 010 12.9v-1.378c0-.357.292-.646.652-.646.36 0 .653.29.653.646V12.9c0 .418.343.757.766.757s.766-.339.766-.757V9.099a2.06 2.06 0 012.07-2.048 2.06 2.06 0 012.071 2.048v8.984c0 .419.343.758.767.758.423 0 .766-.339.766-.758V4.07c0-1.143.937-2.07 2.093-2.07z"
/>
</svg>
{:else if family === "kimi"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19.738 5.776c.163-.209.306-.4.457-.585.07-.087.064-.153-.004-.244-.655-.861-.717-1.817-.34-2.787.283-.73.909-1.072 1.674-1.145.477-.045.945.004 1.379.236.57.305.902.77 1.01 1.412.086.512.07 1.012-.075 1.508-.257.878-.888 1.333-1.753 1.448-.718.096-1.446.108-2.17.157-.056.004-.113 0-.178 0z"
/>
<path
d="M17.962 1.844h-4.326l-3.425 7.81H5.369V1.878H1.5V22h3.87v-8.477h6.824a3.025 3.025 0 002.743-1.75V22h3.87v-8.477a3.87 3.87 0 00-3.588-3.86v-.01h-2.125a3.94 3.94 0 002.323-2.12l2.545-5.689z"
/>
</svg>
{:else if family === "huggingface"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.025 1.13c-5.77 0-10.449 4.647-10.449 10.378 0 1.112.178 2.181.503 3.185.064-.222.203-.444.416-.577a.96.96 0 0 1 .524-.15c.293 0 .584.124.84.284.278.173.48.408.71.694.226.282.458.611.684.951v-.014c.017-.324.106-.622.264-.874s.403-.487.762-.543c.3-.047.596.06.787.203s.31.313.4.467c.15.257.212.468.233.542.01.026.653 1.552 1.657 2.54.616.605 1.01 1.223 1.082 1.912.055.537-.096 1.059-.38 1.572.637.121 1.294.187 1.967.187.657 0 1.298-.063 1.921-.178-.287-.517-.44-1.041-.384-1.581.07-.69.465-1.307 1.081-1.913 1.004-.987 1.647-2.513 1.657-2.539.021-.074.083-.285.233-.542.09-.154.208-.323.4-.467a1.08 1.08 0 0 1 .787-.203c.359.056.604.29.762.543s.247.55.265.874v.015c.225-.34.457-.67.683-.952.23-.286.432-.52.71-.694.257-.16.547-.284.84-.285a.97.97 0 0 1 .524.151c.228.143.373.388.43.625l.006.04a10.3 10.3 0 0 0 .534-3.273c0-5.731-4.678-10.378-10.449-10.378M8.327 6.583a1.5 1.5 0 0 1 .713.174 1.487 1.487 0 0 1 .617 2.013c-.183.343-.762-.214-1.102-.094-.38.134-.532.914-.917.71a1.487 1.487 0 0 1 .69-2.803m7.486 0a1.487 1.487 0 0 1 .689 2.803c-.385.204-.536-.576-.916-.71-.34-.12-.92.437-1.103.094a1.487 1.487 0 0 1 .617-2.013 1.5 1.5 0 0 1 .713-.174m-10.68 1.55a.96.96 0 1 1 0 1.921.96.96 0 0 1 0-1.92m13.838 0a.96.96 0 1 1 0 1.92.96.96 0 0 1 0-1.92M8.489 11.458c.588.01 1.965 1.157 3.572 1.164 1.607-.007 2.984-1.155 3.572-1.164.196-.003.305.12.305.454 0 .886-.424 2.328-1.563 3.202-.22-.756-1.396-1.366-1.63-1.32q-.011.001-.02.006l-.044.026-.01.008-.03.024q-.018.017-.035.036l-.032.04a1 1 0 0 0-.058.09l-.014.025q-.049.088-.11.19a1 1 0 0 1-.083.116 1.2 1.2 0 0 1-.173.18q-.035.029-.075.058a1.3 1.3 0 0 1-.251-.243 1 1 0 0 1-.076-.107c-.124-.193-.177-.363-.337-.444-.034-.016-.104-.008-.2.022q-.094.03-.216.087-.06.028-.125.063l-.13.074q-.067.04-.136.086a3 3 0 0 0-.135.096 3 3 0 0 0-.26.219 2 2 0 0 0-.12.121 2 2 0 0 0-.106.128l-.002.002a2 2 0 0 0-.09.132l-.001.001a1.2 1.2 0 0 0-.105.212q-.013.036-.024.073c-1.139-.875-1.563-2.317-1.563-3.203 0-.334.109-.457.305-.454m.836 
10.354c.824-1.19.766-2.082-.365-3.194-1.13-1.112-1.789-2.738-1.789-2.738s-.246-.945-.806-.858-.97 1.499.202 2.362c1.173.864-.233 1.45-.685.64-.45-.812-1.683-2.896-2.322-3.295s-1.089-.175-.938.647 2.822 2.813 2.562 3.244-1.176-.506-1.176-.506-2.866-2.567-3.49-1.898.473 1.23 2.037 2.16c1.564.932 1.686 1.178 1.464 1.53s-3.675-2.511-4-1.297c-.323 1.214 3.524 1.567 3.287 2.405-.238.839-2.71-1.587-3.216-.642-.506.946 3.49 2.056 3.522 2.064 1.29.33 4.568 1.028 5.713-.624m5.349 0c-.824-1.19-.766-2.082.365-3.194 1.13-1.112 1.789-2.738 1.789-2.738s.246-.945.806-.858.97 1.499-.202 2.362c-1.173.864.233 1.45.685.64.451-.812 1.683-2.896 2.322-3.295s1.089-.175.938.647-2.822 2.813-2.562 3.244 1.176-.506 1.176-.506 2.866-2.567 3.49-1.898-.473 1.23-2.037 2.16c-1.564.932-1.686 1.178-1.464 1.53s3.675-2.511 4-1.297c.323 1.214-3.524 1.567-3.287 2.405.238.839 2.71-1.587 3.216-.642.506.946-3.49 2.056-3.522 2.064-1.29.33-4.568 1.028-5.713-.624"
/>
</svg>
{:else}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
{/if}
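The `{:else if}` chain above picks one inline SVG per family string, lets `openai` and `gpt-oss` share a logo, and falls back to a generic icon for anything unrecognized. The same selection can be sketched as a lookup table. This is a minimal sketch: `LogoKey` and `resolveLogoKey` are hypothetical names that do not appear in the change set, and only the branches visible in this part of the diff are listed.

```typescript
// Hypothetical sketch: these names are illustrative and not part of the diff.
type LogoKey =
  | "openai"
  | "glm"
  | "minimax"
  | "kimi"
  | "huggingface"
  | "default";

// A Map avoids accidental hits on Object.prototype keys such as "constructor".
const LOGO_KEYS = new Map<string, LogoKey>([
  ["openai", "openai"],
  ["gpt-oss", "openai"], // alias: shares the OpenAI logo
  ["glm", "glm"],
  ["minimax", "minimax"],
  ["kimi", "kimi"],
  ["huggingface", "huggingface"],
]);

// Unknown families fall through to the generic icon, like the final {:else}.
function resolveLogoKey(family: string): LogoKey {
  return LOGO_KEYS.get(family) ?? "default";
}

console.log(resolveLogoKey("gpt-oss"));
console.log(resolveLogoKey("some-new-family"));
```

A table like this keeps aliases in one place, at the cost of moving the SVG markup out of the template.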


@@ -0,0 +1,142 @@
<script lang="ts">
import FamilyLogos from "./FamilyLogos.svelte";
type FamilySidebarProps = {
families: string[];
selectedFamily: string | null;
hasFavorites: boolean;
onSelect: (family: string | null) => void;
};
let { families, selectedFamily, hasFavorites, onSelect }: FamilySidebarProps =
$props();
// Family display names
const familyNames: Record<string, string> = {
favorites: "Favorites",
huggingface: "Hub",
llama: "Meta",
qwen: "Qwen",
deepseek: "DeepSeek",
openai: "OpenAI",
"gpt-oss": "OpenAI",
glm: "GLM",
minimax: "MiniMax",
kimi: "Kimi",
};
function getFamilyName(family: string): string {
return (
familyNames[family] || family.charAt(0).toUpperCase() + family.slice(1)
);
}
</script>
<div
class="flex flex-col gap-1 py-2 px-1 border-r border-exo-yellow/10 bg-exo-medium-gray/30 min-w-[64px]"
>
<!-- All models (no filter) -->
<button
type="button"
onclick={() => onSelect(null)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
null
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="All models"
>
<svg
class="w-5 h-5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/50 group-hover:text-white/70'}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M4 8h4V4H4v4zm6 12h4v-4h-4v4zm-6 0h4v-4H4v4zm0-6h4v-4H4v4zm6 0h4v-4h-4v4zm6-10v4h4V4h-4zm-6 4h4V4h-4v4zm6 6h4v-4h-4v4zm0 6h4v-4h-4v4z"
/>
</svg>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}">All</span
>
</button>
<!-- Favorites (only show if has favorites) -->
{#if hasFavorites}
<button
type="button"
onclick={() => onSelect("favorites")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'favorites'
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Show favorited models"
>
<FamilyLogos
family="favorites"
class={selectedFamily === "favorites"
? "text-amber-400"
: "text-white/50 group-hover:text-amber-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'favorites'
? 'text-amber-400'
: 'text-white/40 group-hover:text-white/60'}">Faves</span
>
</button>
{/if}
<!-- HuggingFace Hub -->
<button
type="button"
onclick={() => onSelect("huggingface")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'huggingface'
? 'bg-orange-500/20 border-l-2 border-orange-400'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Browse and add models from Hugging Face"
>
<FamilyLogos
family="huggingface"
class={selectedFamily === "huggingface"
? "text-orange-400"
: "text-white/50 group-hover:text-orange-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'huggingface'
? 'text-orange-400'
: 'text-white/40 group-hover:text-white/60'}">Hub</span
>
</button>
<div class="h-px bg-exo-yellow/10 my-1"></div>
<!-- Model families -->
{#each families as family}
<button
type="button"
onclick={() => onSelect(family)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
family
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title={getFamilyName(family)}
>
<FamilyLogos
{family}
class={selectedFamily === family
? "text-exo-yellow"
: "text-white/50 group-hover:text-white/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 truncate max-w-full {selectedFamily ===
family
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}"
>
{getFamilyName(family)}
</span>
</button>
{/each}
</div>
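The sidebar labels come from `getFamilyName`, which prefers an entry in `familyNames` and otherwise capitalizes the first letter of the raw family key. The standalone sketch below copies that logic out of the component so the fallback can be tried in isolation; the `mistral` key is a made-up input used only to exercise the fallback path.

```typescript
// Display-name table copied from the component above.
const familyNames: Record<string, string> = {
  favorites: "Favorites",
  huggingface: "Hub",
  llama: "Meta",
  qwen: "Qwen",
  deepseek: "DeepSeek",
  "gpt-oss": "OpenAI",
  glm: "GLM",
  minimax: "MiniMax",
  kimi: "Kimi",
};

function getFamilyName(family: string): string {
  // Known families use the curated label; unknown ones get a capitalized key.
  return (
    familyNames[family] || family.charAt(0).toUpperCase() + family.slice(1)
  );
}

console.log(getFamilyName("gpt-oss")); // "OpenAI" (table hit)
console.log(getFamilyName("mistral")); // "Mistral" (hypothetical key, fallback)
```

The fallback only fixes the first character, so families whose branding uses internal capitals need an explicit table entry.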


@@ -0,0 +1,151 @@
<script lang="ts">
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type HuggingFaceResultItemProps = {
model: HuggingFaceModel;
isAdded: boolean;
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
downloadedOnNodes?: string[];
};
let {
model,
isAdded,
isAdding,
onAdd,
onSelect,
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {
return `${(num / 1000).toFixed(1)}k`;
}
return num.toString();
}
// Extract model name from full ID (e.g., "mlx-community/Llama-3.2-1B" -> "Llama-3.2-1B")
const modelName = $derived(model.id.split("/").pop() || model.id);
</script>
<div
class="flex items-center justify-between gap-3 px-3 py-2.5 hover:bg-white/5 transition-colors border-b border-white/5 last:border-b-0"
>
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if downloadedOnNodes.length > 0}
<span
class="flex-shrink-0"
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
>Added</span
>
{/if}
</div>
<div class="flex items-center gap-3 mt-0.5 text-xs text-white/40">
<span class="truncate">{model.author}</span>
<span
class="flex items-center gap-1 shrink-0"
title="Downloads in the last 30 days"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"
/>
</svg>
{formatNumber(model.downloads)}
</span>
<span
class="flex items-center gap-1 shrink-0"
title="Community likes on Hugging Face"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4.318 6.318a4.5 4.5 0 000 6.364L12 20.364l7.682-7.682a4.5 4.5 0 00-6.364-6.364L12 7.636l-1.318-1.318a4.5 4.5 0 00-6.364 0z"
/>
</svg>
{formatNumber(model.likes)}
</span>
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
{#if isAdded}
<button
type="button"
onclick={onSelect}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-yellow/10 text-exo-yellow border border-exo-yellow/30 hover:bg-exo-yellow/20 transition-colors rounded cursor-pointer"
>
Select
</button>
{:else}
<button
type="button"
onclick={onAdd}
disabled={isAdding}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed"
>
{#if isAdding}
<span class="flex items-center gap-1.5">
<span
class="w-3 h-3 border-2 border-orange-400 border-t-transparent rounded-full animate-spin"
></span>
Adding...
</span>
{:else}
+ Add
{/if}
</button>
{/if}
</div>
</div>
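Two small helpers drive the text in each result row: `formatNumber` abbreviates download and like counts, and the derived `modelName` strips the organization prefix from the repository ID. The sketch below restates both as plain functions so their behavior is easy to check; `shortModelName` is a hypothetical name for the expression the component computes inline with `$derived`.

```typescript
// Same thresholds as the component: >= 1,000,000 gives "M", >= 1,000 gives "k".
function formatNumber(num: number): string {
  if (num >= 1000000) {
    return `${(num / 1000000).toFixed(1)}M`;
  } else if (num >= 1000) {
    return `${(num / 1000).toFixed(1)}k`;
  }
  return num.toString();
}

// Hypothetical helper name; the component computes this inline.
function shortModelName(id: string): string {
  // Take the last path segment, falling back to the full ID if it is empty.
  return id.split("/").pop() || id;
}

console.log(formatNumber(999)); // "999"
console.log(formatNumber(1500)); // "1.5k"
console.log(formatNumber(2500000)); // "2.5M"
console.log(shortModelName("mlx-community/Llama-3.2-1B")); // "Llama-3.2-1B"
```

One edge worth knowing about: values just under a threshold round up within their own unit, so `formatNumber(999999)` yields `"1000.0k"` rather than `"1.0M"`.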


@@ -0,0 +1,213 @@
<script lang="ts">
import { fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
type ModelFilterPopoverProps = {
filters: FilterState;
onChange: (filters: FilterState) => void;
onClear: () => void;
onClose: () => void;
};
let { filters, onChange, onClear, onClose }: ModelFilterPopoverProps =
$props();
// Available capabilities
const availableCapabilities = [
{ id: "text", label: "Text" },
{ id: "thinking", label: "Thinking" },
{ id: "code", label: "Code" },
{ id: "vision", label: "Vision" },
];
// Size ranges
const sizeRanges = [
{ label: "< 10GB", min: 0, max: 10 },
{ label: "10-50GB", min: 10, max: 50 },
{ label: "50-200GB", min: 50, max: 200 },
{ label: "> 200GB", min: 200, max: 10000 },
];
function toggleCapability(cap: string) {
const next = filters.capabilities.includes(cap)
? filters.capabilities.filter((c) => c !== cap)
: [...filters.capabilities, cap];
onChange({ ...filters, capabilities: next });
}
function selectSizeRange(range: { min: number; max: number } | null) {
// Toggle off if same range is clicked
if (
filters.sizeRange &&
range &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max
) {
onChange({ ...filters, sizeRange: null });
} else {
onChange({ ...filters, sizeRange: range });
}
}
function handleClickOutside(e: MouseEvent) {
const target = e.target as HTMLElement;
if (
!target.closest(".filter-popover") &&
!target.closest(".filter-toggle")
) {
onClose();
}
}
</script>
<svelte:window onclick={handleClickOutside} />
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="filter-popover absolute right-0 top-full mt-2 w-64 bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-xl z-10"
transition:fly={{ y: -10, duration: 200, easing: cubicOut }}
onclick={(e) => e.stopPropagation()}
role="dialog"
aria-label="Filter options"
>
<div class="p-3 space-y-4">
<!-- Capabilities -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Capabilities</h4>
<div class="flex flex-wrap gap-1.5">
{#each availableCapabilities as cap}
{@const isSelected = filters.capabilities.includes(cap.id)}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => toggleCapability(cap.id)}
>
{#if cap.id === "text"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "thinking"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "code"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "vision"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/><circle cx="12" cy="12" r="3" /></svg
>
{/if}
<span class="ml-1">{cap.label}</span>
</button>
{/each}
</div>
</div>
<!-- Downloaded only -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
? 'bg-green-500/20 text-green-400 border border-green-500/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() =>
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
>
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="ml-1">Downloaded</span>
</button>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
<div class="flex flex-wrap gap-1.5">
{#each sizeRanges as range}
{@const isSelected =
filters.sizeRange &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => selectSizeRange(range)}
>
{range.label}
</button>
{/each}
</div>
</div>
<!-- Clear button -->
<button
type="button"
class="w-full py-1.5 text-xs font-mono text-white/50 hover:text-white/70 hover:bg-white/5 rounded transition-colors"
onclick={onClear}
>
Clear all filters
</button>
</div>
</div>
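The popover never mutates `filters`; each handler builds a new `FilterState` and hands it to `onChange`. The sketch below expresses `toggleCapability` and `selectSizeRange` as pure functions that return the next state instead of calling a callback. Passing the current state as the first parameter is an assumption made for illustration, since the component reads `filters` from its props.

```typescript
interface SizeRange {
  min: number;
  max: number;
}

interface FilterState {
  capabilities: string[];
  sizeRange: SizeRange | null;
  downloadedOnly: boolean;
}

// Add the capability if absent, remove it if present.
function toggleCapability(filters: FilterState, cap: string): FilterState {
  const next = filters.capabilities.includes(cap)
    ? filters.capabilities.filter((c) => c !== cap)
    : [...filters.capabilities, cap];
  return { ...filters, capabilities: next };
}

// Selecting the already-active range clears it; anything else replaces it.
function selectSizeRange(
  filters: FilterState,
  range: SizeRange | null,
): FilterState {
  const isSame =
    filters.sizeRange !== null &&
    range !== null &&
    filters.sizeRange.min === range.min &&
    filters.sizeRange.max === range.max;
  return { ...filters, sizeRange: isSame ? null : range };
}

const empty: FilterState = {
  capabilities: [],
  sizeRange: null,
  downloadedOnly: false,
};

const withCode = toggleCapability(empty, "code");
const withRange = selectSizeRange(withCode, { min: 10, max: 50 });
const cleared = selectSizeRange(withRange, { min: 10, max: 50 });
console.log(withCode.capabilities, withRange.sizeRange, cleared.sizeRange);
```

Comparing ranges by `min` and `max` rather than by object identity is what lets a second click on the same chip toggle it off, even when the caller passes a fresh object.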


@@ -0,0 +1,401 @@
<script lang="ts">
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
isFavorite: boolean;
selectedModelId: string | null;
canModelFit: (id: string) => boolean;
onToggleExpand: () => void;
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatusMap?: Map<string, DownloadAvailability>;
};
let {
group,
isExpanded,
isFavorite,
selectedModelId,
canModelFit,
onToggleExpand,
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatusMap,
}: ModelPickerGroupProps = $props();
// Group-level download status: show if any variant is downloaded
const groupDownloadStatus = $derived.by(() => {
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
// Return the first available entry (prefer "available" ones)
for (const avail of downloadStatusMap.values()) {
if (avail.available) return avail;
}
return downloadStatusMap.values().next().value;
});
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
if (mb >= 1024) {
return `${(mb / 1024).toFixed(0)}GB`;
}
return `${mb}MB`;
}
// Check if any variant can fit
const anyVariantFits = $derived(
group.variants.some((v) => canModelFit(v.id)),
);
// Check if this group's model is currently selected (for single-variant groups)
const isMainSelected = $derived(
!group.hasMultipleVariants &&
group.variants.some((v) => v.id === selectedModelId),
);
</script>
<div
class="border-b border-white/5 last:border-b-0 {!anyVariantFits
? 'opacity-50'
: ''}"
>
<!-- Main row -->
<div
class="flex items-center gap-2 px-3 py-2.5 transition-colors {anyVariantFits
? 'hover:bg-white/5 cursor-pointer'
: 'cursor-not-allowed'} {isMainSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
onclick={() => {
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}}
role="button"
tabindex="0"
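onkeydown={(e) => {
// Only handle keys pressed on the row itself, so Enter/Space on the nested
// favorite and info buttons still activates those buttons.
if (e.target !== e.currentTarget) return;
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}
}}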
>
<!-- Expand/collapse chevron (for groups with variants) -->
{#if group.hasMultipleVariants}
<svg
class="w-4 h-4 text-white/40 transition-transform duration-200 flex-shrink-0 {isExpanded
? 'rotate-90'
: ''}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M8.59 16.59L13.17 12 8.59 7.41 10 6l6 6-6 6-1.41-1.41z" />
</svg>
{:else}
<div class="w-4 flex-shrink-0"></div>
{/if}
<!-- Model name -->
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="font-mono text-sm text-white truncate">
{group.name}
</span>
<!-- Capability icons -->
{#each group.capabilities.filter((c) => c !== "text") as cap}
{#if cap === "thinking"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
>
<title>Supports thinking</title>
<path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "code"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
>
<title>Supports code generation</title>
<path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "vision"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
>
<title>Supports image input</title>
<path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/>
<circle cx="12" cy="12" r="3" />
</svg>
{:else if cap === "image_gen"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
>
<title>Supports image generation</title>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<path d="M21 15l-5-5L5 21" />
</svg>
{/if}
{/each}
</div>
</div>
<!-- Size indicator (smallest variant) -->
{#if !group.hasMultipleVariants && group.smallestVariant?.storage_size_megabytes}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{formatSize(group.smallestVariant.storage_size_megabytes)}
</span>
{/if}
<!-- Variant count with size range -->
{#if group.hasMultipleVariants}
{@const sizes = group.variants
.map((v) => v.storage_size_megabytes || 0)
.filter((s) => s > 0)
.sort((a, b) => a - b)}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
sizes[0],
)}-{formatSize(sizes[sizes.length - 1])}){/if}
</span>
{/if}
<!-- Download availability indicator -->
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={groupDownloadStatus.available
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
<!-- Check mark if selected (single-variant) -->
{#if isMainSelected}
<svg
class="w-4 h-4 text-exo-yellow flex-shrink-0"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z" />
</svg>
{/if}
<!-- Favorite star -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onToggleFavorite(group.id);
}}
title={isFavorite ? "Remove from favorites" : "Add to favorites"}
>
{#if isFavorite}
<svg
class="w-4 h-4 text-amber-400"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else}
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{/if}
</button>
<!-- Info button -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onShowInfo(group);
}}
title="View model details"
>
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z"
/>
</svg>
</button>
</div>
<!-- Expanded variants -->
{#if isExpanded && group.hasMultipleVariants}
<div class="bg-black/20 border-t border-white/5">
{#each group.variants as variant}
{@const modelCanFit = canModelFit(variant.id)}
{@const isSelected = selectedModelId === variant.id}
<button
type="button"
class="w-full flex items-center gap-3 px-3 py-2 pl-10 hover:bg-white/5 transition-colors text-left {!modelCanFit
? 'opacity-50 cursor-not-allowed'
: 'cursor-pointer'} {isSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
disabled={!modelCanFit}
onclick={() => {
if (modelCanFit) {
onSelectModel(variant.id);
}
}}
>
<!-- Quantization badge -->
<span
class="text-xs font-mono px-1.5 py-0.5 rounded bg-white/10 text-white/70 flex-shrink-0"
>
{variant.quantization || "default"}
</span>
<!-- Size -->
<span class="text-xs font-mono text-white/40 flex-1">
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Download indicator for this variant -->
{#if downloadStatusMap?.get(variant.id)}
{@const variantDl = downloadStatusMap.get(variant.id)}
{#if variantDl}
<span
class="flex-shrink-0"
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{/if}
<!-- Check mark if selected -->
{#if isSelected}
<svg
class="w-4 h-4 text-exo-yellow"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z"
/>
</svg>
{/if}
</button>
{/each}
</div>
{/if}
</div>

@@ -0,0 +1,882 @@
<script lang="ts">
import { fade, fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
import FamilySidebar from "./FamilySidebar.svelte";
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
tasks?: string[];
hugging_face_id?: string;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type ModelPickerModalProps = {
isOpen: boolean;
models: ModelInfo[];
selectedModelId: string | null;
favorites: Set<string>;
existingModelIds: Set<string>;
canModelFit: (modelId: string) => boolean;
onSelect: (modelId: string) => void;
onClose: () => void;
onToggleFavorite: (baseModelId: string) => void;
onAddModel: (modelId: string) => Promise<void>;
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
isOpen,
models,
selectedModelId,
favorites,
existingModelIds,
canModelFit,
onSelect,
onClose,
onToggleFavorite,
onAddModel,
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
let searchQuery = $state("");
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({
capabilities: [],
sizeRange: null,
downloadedOnly: false,
});
let infoGroup = $state<ModelGroup | null>(null);
// Download availability for a single model (one entry per variant id)
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Download info for a group: returns the first variant that is downloaded on at least one node
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// Get per-variant download map for a group
function getVariantDownloadMap(
group: ModelGroup,
): Map<string, DownloadAvailability> {
const map = new Map<string, DownloadAvailability>();
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
}
return map;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
let hfTrendingModels = $state<HuggingFaceModel[]>([]);
let hfIsSearching = $state(false);
let hfIsLoadingTrending = $state(false);
let addingModelId = $state<string | null>(null);
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset transient state when modal opens, but preserve tab selection
$effect(() => {
if (isOpen) {
searchQuery = "";
expandedGroups = new Set();
showFilters = false;
manualModelId = "";
addModelError = null;
}
});
// Fetch trending models when HuggingFace is selected
$effect(() => {
if (
selectedFamily === "huggingface" &&
hfTrendingModels.length === 0 &&
!hfIsLoadingTrending
) {
fetchTrendingModels();
}
});
async function fetchTrendingModels() {
hfIsLoadingTrending = true;
try {
const response = await fetch("/models/search?query=&limit=20");
if (response.ok) {
hfTrendingModels = await response.json();
}
} catch (error) {
console.error("Failed to fetch trending models:", error);
} finally {
hfIsLoadingTrending = false;
}
}
async function searchHuggingFace(query: string) {
if (query.length < 2) {
hfSearchResults = [];
return;
}
hfIsSearching = true;
try {
const response = await fetch(
`/models/search?query=${encodeURIComponent(query)}&limit=20`,
);
if (response.ok) {
hfSearchResults = await response.json();
} else {
hfSearchResults = [];
}
} catch (error) {
console.error("Failed to search models:", error);
hfSearchResults = [];
} finally {
hfIsSearching = false;
}
}
function handleHfSearchInput(query: string) {
hfSearchQuery = query;
addModelError = null;
if (hfSearchDebounceTimer) {
clearTimeout(hfSearchDebounceTimer);
}
if (query.length >= 2) {
hfSearchDebounceTimer = setTimeout(() => {
searchHuggingFace(query);
}, 300);
} else {
hfSearchResults = [];
}
}
async function handleAddModel(modelId: string) {
addingModelId = modelId;
addModelError = null;
try {
await onAddModel(modelId);
} catch (error) {
addModelError =
error instanceof Error ? error.message : "Failed to add model";
} finally {
addingModelId = null;
}
}
async function handleAddManualModel() {
if (!manualModelId.trim()) return;
await handleAddModel(manualModelId.trim());
if (!addModelError) {
manualModelId = "";
}
}
function handleSelectHfModel(modelId: string) {
onSelect(modelId);
onClose();
}
// Models to display in HuggingFace view
const hfDisplayModels = $derived.by((): HuggingFaceModel[] => {
if (hfSearchQuery.length >= 2) {
return hfSearchResults;
}
return hfTrendingModels;
});
// Group models by base_model
const groupedModels = $derived.by((): ModelGroup[] => {
const groups = new Map<string, ModelGroup>();
for (const model of models) {
const groupId = model.base_model || model.id;
const groupName = model.base_model || model.name || model.id;
if (!groups.has(groupId)) {
groups.set(groupId, {
id: groupId,
name: groupName,
capabilities: model.capabilities || ["text"],
family: model.family || "",
variants: [],
smallestVariant: model,
hasMultipleVariants: false,
});
}
const group = groups.get(groupId)!;
group.variants.push(model);
// Track smallest variant
if (
(model.storage_size_megabytes || 0) <
(group.smallestVariant.storage_size_megabytes || Infinity)
) {
group.smallestVariant = model;
}
// Update capabilities if not set
if (
group.capabilities.length <= 1 &&
model.capabilities &&
model.capabilities.length > 1
) {
group.capabilities = model.capabilities;
}
if (!group.family && model.family) {
group.family = model.family;
}
}
// Sort variants within each group by size
for (const group of groups.values()) {
group.variants.sort(
(a, b) =>
(a.storage_size_megabytes || 0) - (b.storage_size_megabytes || 0),
);
group.hasMultipleVariants = group.variants.length > 1;
}
// Convert to array and sort by smallest variant size (biggest first)
return Array.from(groups.values()).sort((a, b) => {
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
});
// Get unique families
const uniqueFamilies = $derived.by((): string[] => {
const families = new Set<string>();
for (const group of groupedModels) {
if (group.family) {
families.add(group.family);
}
}
const familyOrder = [
"kimi",
"qwen",
"glm",
"minimax",
"deepseek",
"gpt-oss",
"llama",
];
return Array.from(families).sort((a, b) => {
const aIdx = familyOrder.indexOf(a);
const bIdx = familyOrder.indexOf(b);
if (aIdx === -1 && bIdx === -1) return a.localeCompare(b);
if (aIdx === -1) return 1;
if (bIdx === -1) return -1;
return aIdx - bIdx;
});
});
// Filter models based on search, family, and filters
const filteredGroups = $derived.by((): ModelGroup[] => {
let result: ModelGroup[] = [...groupedModels];
// Filter by family
if (selectedFamily === "favorites") {
result = result.filter((g) => favorites.has(g.id));
} else if (selectedFamily && selectedFamily !== "huggingface") {
result = result.filter((g) => g.family === selectedFamily);
}
// Filter by search query
if (searchQuery.trim()) {
const query = searchQuery.toLowerCase().trim();
result = result.filter(
(g) =>
g.name.toLowerCase().includes(query) ||
g.variants.some(
(v) =>
v.id.toLowerCase().includes(query) ||
(v.name || "").toLowerCase().includes(query),
),
);
}
// Filter by capabilities
if (filters.capabilities.length > 0) {
result = result.filter((g) =>
filters.capabilities.every((cap) => g.capabilities.includes(cap)),
);
}
// Filter by size range
if (filters.sizeRange) {
const { min, max } = filters.sizeRange;
result = result.filter((g) => {
const sizeGB = (g.smallestVariant.storage_size_megabytes || 0) / 1024;
return sizeGB >= min && sizeGB <= max;
});
}
// Filter to downloaded models only
if (filters.downloadedOnly) {
result = result.filter((g) =>
g.variants.some((v) => {
const avail = modelDownloadAvailability.get(v.id);
return avail && avail.nodeIds.length > 0;
}),
);
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
const bFits = b.variants.some((v) => canModelFit(v.id));
if (aFits && !bFits) return -1;
if (!aFits && bFits) return 1;
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
return result;
});
// Check if any favorites exist
const hasFavorites = $derived(favorites.size > 0);
function toggleGroupExpanded(groupId: string) {
const next = new Set(expandedGroups);
if (next.has(groupId)) {
next.delete(groupId);
} else {
next.add(groupId);
}
expandedGroups = next;
}
function handleSelect(modelId: string) {
onSelect(modelId);
onClose();
}
function handleKeydown(e: KeyboardEvent) {
if (e.key === "Escape") {
onClose();
}
}
function handleFiltersChange(newFilters: FilterState) {
filters = newFilters;
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 ||
filters.sizeRange !== null ||
filters.downloadedOnly,
);
</script>
<svelte:window onkeydown={handleKeydown} />
{#if isOpen}
<!-- Backdrop -->
<div
class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
transition:fade={{ duration: 200 }}
onclick={onClose}
role="presentation"
></div>
<!-- Modal -->
<div
class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,600px)] h-[min(80vh,700px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
role="dialog"
aria-modal="true"
aria-label="Select a model"
>
<!-- Header with search -->
<div
class="flex items-center gap-2 p-3 border-b border-exo-yellow/10 bg-exo-medium-gray/30"
>
{#if selectedFamily === "huggingface"}
<!-- HuggingFace search -->
<svg
class="w-5 h-5 text-orange-400/60 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search mlx-community models..."
value={hfSearchQuery}
oninput={(e) => handleHfSearchInput(e.currentTarget.value)}
/>
{#if hfIsSearching}
<div class="flex-shrink-0">
<span
class="w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin block"
></span>
</div>
{/if}
{:else}
<!-- Normal model search -->
<svg
class="w-5 h-5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search models..."
bind:value={searchQuery}
/>
<!-- Cluster memory -->
<span
class="text-xs font-mono flex-shrink-0"
title="Cluster memory usage"
><span class="text-exo-yellow">{Math.round(usedMemoryGB)}GB</span
><span class="text-white/40">/{Math.round(totalMemoryGB)}GB</span
></span
>
<!-- Filter button -->
<div class="relative filter-toggle">
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors {hasActiveFilters
? 'text-exo-yellow'
: 'text-white/50'}"
onclick={() => (showFilters = !showFilters)}
title="Filter by capability, size, or download status"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path d="M10 18h4v-2h-4v2zM3 6v2h18V6H3zm3 7h12v-2H6v2z" />
</svg>
</button>
{#if showFilters}
<ModelFilterPopover
{filters}
onChange={handleFiltersChange}
onClear={clearFilters}
onClose={() => (showFilters = false)}
/>
{/if}
</div>
{/if}
<!-- Close button -->
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors text-white/50 hover:text-white/70"
onclick={onClose}
title="Close model picker"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<!-- Body -->
<div class="flex flex-1 overflow-hidden">
<!-- Family sidebar -->
<FamilySidebar
families={uniqueFamilies}
{selectedFamily}
{hasFavorites}
onSelect={(family) => (selectedFamily = family)}
/>
<!-- Model list -->
<div class="flex-1 overflow-y-auto flex flex-col">
{#if selectedFamily === "huggingface"}
<!-- HuggingFace Hub view -->
<div class="flex-1 flex flex-col min-h-0">
<!-- Section header -->
<div
class="sticky top-0 z-10 px-3 py-2 bg-exo-dark-gray/95 border-b border-exo-yellow/10"
>
<span class="text-xs font-mono text-white/40">
{#if hfSearchQuery.length >= 2}
Search results for "{hfSearchQuery}"
{:else}
Trending on mlx-community
{/if}
</span>
</div>
<!-- Results list -->
<div class="flex-1 overflow-y-auto">
{#if hfIsLoadingTrending && hfTrendingModels.length === 0}
<div
class="flex items-center justify-center py-12 text-white/40"
>
<span
class="w-5 h-5 border-2 border-orange-400 border-t-transparent rounded-full animate-spin mr-2"
></span>
<span class="font-mono text-sm"
>Loading trending models...</span
>
</div>
{:else if hfDisplayModels.length === 0}
<div
class="flex flex-col items-center justify-center py-12 text-white/40"
>
<svg
class="w-10 h-10 mb-2"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 13.5c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm4 0c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm2-4.5H8c0-2.21 1.79-4 4-4s4 1.79 4 4z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hfSearchQuery}
<p class="font-mono text-xs mt-1">
Try a different search term
</p>
{/if}
</div>
{:else}
{#each hfDisplayModels as model}
<HuggingFaceResultItem
{model}
isAdded={existingModelIds.has(model.id)}
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
downloadedOnNodes={downloadsData
? getNodesWithModelDownloaded(
downloadsData,
model.id,
).map(getNodeName)
: []}
/>
{/each}
{/if}
</div>
<!-- Manual input footer -->
<div
class="sticky bottom-0 border-t border-exo-yellow/10 bg-exo-dark-gray p-3"
>
{#if addModelError}
<div
class="bg-red-500/10 border border-red-500/30 rounded px-3 py-2 mb-2"
>
<p class="text-red-400 text-xs font-mono break-words">
{addModelError}
</p>
</div>
{/if}
<div class="flex gap-2">
<input
type="text"
class="flex-1 bg-exo-black/60 border border-exo-yellow/30 rounded px-3 py-1.5 text-xs font-mono text-white placeholder-white/30 focus:outline-none focus:border-exo-yellow/50"
placeholder="Or paste model ID directly..."
bind:value={manualModelId}
onkeydown={(e) => {
if (e.key === "Enter") handleAddManualModel();
}}
/>
<button
type="button"
onclick={handleAddManualModel}
disabled={!manualModelId.trim() || addingModelId !== null}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded disabled:opacity-50 disabled:cursor-not-allowed"
>
Add
</button>
</div>
</div>
</div>
{:else if filteredGroups.length === 0}
<div
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
>
<svg class="w-12 h-12 mb-3" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hasActiveFilters || searchQuery}
<button
type="button"
class="mt-2 text-xs text-exo-yellow hover:underline"
onclick={() => {
searchQuery = "";
clearFilters();
}}
>
Clear filters
</button>
{/if}
</div>
{:else}
{#each filteredGroups as group}
<ModelPickerGroup
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatusMap={getVariantDownloadMap(group)}
/>
{/each}
{/if}
</div>
</div>
<!-- Footer with active filters indicator -->
{#if hasActiveFilters}
<div
class="flex items-center gap-2 px-3 py-2 border-t border-exo-yellow/10 bg-exo-medium-gray/20 text-xs font-mono text-white/50"
>
<span>Filters:</span>
{#each filters.capabilities as cap}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded"
>{cap}</span
>
{/each}
{#if filters.downloadedOnly}
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
>Downloaded</span
>
{/if}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
</span>
{/if}
<button
type="button"
class="ml-auto text-white/40 hover:text-white/60"
onclick={clearFilters}
>
Clear all
</button>
</div>
{/if}
</div>
<!-- Info modal -->
{#if infoGroup}
<div
class="fixed inset-0 z-[60] bg-black/60"
transition:fade={{ duration: 150 }}
onclick={() => (infoGroup = null)}
role="presentation"
></div>
<div
class="fixed z-[60] top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(80vw,400px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl p-4"
transition:fly={{ y: 10, duration: 200, easing: cubicOut }}
role="dialog"
aria-modal="true"
>
<div class="flex items-start justify-between mb-3">
<h3 class="font-mono text-lg text-white">{infoGroup.name}</h3>
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors text-white/50"
onclick={() => (infoGroup = null)}
title="Close model details"
aria-label="Close info dialog"
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<div class="space-y-2 text-xs font-mono">
<div class="flex items-center gap-2">
<span class="text-white/40">Family:</span>
<span class="text-white/70">{infoGroup.family || "Unknown"}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Capabilities:</span>
<span class="text-white/70">{infoGroup.capabilities.join(", ")}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Variants:</span>
<span class="text-white/70">{infoGroup.variants.length}</span>
</div>
{#if infoGroup.variants.length > 0}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<span class="text-white/40">Available quantizations:</span>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoGroup.variants as variant}
<span
class="px-1.5 py-0.5 bg-white/10 text-white/60 rounded text-[10px]"
>
{variant.quantization || "default"} ({Math.round(
(variant.storage_size_megabytes || 0) / 1024,
)}GB)
</span>
{/each}
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}
{/if}
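
The download-availability check in `modelDownloadAvailability` above can be read in isolation as follows. This is a minimal sketch, not the component's code: `NodeInfo`, `Availability`, and `computeAvailability` are hypothetical names, the node shape is flattened to the two fields the check needs, and the list of node ids holding the model is passed in directly rather than looked up through `getNodesWithModelDownloaded`.

```typescript
// Hypothetical, simplified shapes: only the fields this sketch needs.
interface NodeInfo {
  friendlyName?: string;
  ramTotalBytes?: number;
}

interface Availability {
  available: boolean; // combined RAM of the holding nodes covers the model size
  nodeNames: string[];
  nodeIds: string[];
}

const BYTES_PER_MEGABYTE = 1024 * 1024;

function computeAvailability(
  modelSizeMegabytes: number,
  nodeIdsWithModel: string[],
  nodes: Record<string, NodeInfo>,
): Availability | undefined {
  // No node has the model: nothing to report.
  if (nodeIdsWithModel.length === 0) return undefined;

  // Sum RAM only across the nodes that already hold the model.
  let totalRamBytes = 0;
  for (const nodeId of nodeIdsWithModel) {
    const ram = nodes[nodeId]?.ramTotalBytes;
    if (typeof ram === "number") totalRamBytes += ram;
  }

  const modelSizeBytes = modelSizeMegabytes * BYTES_PER_MEGABYTE;
  return {
    // A model of unknown size (0) is never reported as available.
    available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
    // Fall back to a short id prefix when the node has no friendly name.
    nodeNames: nodeIdsWithModel.map(
      (id) => nodes[id]?.friendlyName || id.slice(0, 8),
    ),
    nodeIds: nodeIdsWithModel,
  };
}
```

Under these assumptions, a model counts as downloaded as soon as one node holds it, while `available` additionally requires that the holding nodes together have at least as much RAM as the model's storage size.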

@@ -0,0 +1,236 @@
<script lang="ts">
import type { TokenData } from "$lib/stores/app.svelte";
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let {
tokens,
class: className = "",
isGenerating = false,
onRegenerateFrom,
}: Props = $props();
// Tooltip state: track the hovered token's index and the tooltip anchor position
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? {
token: tokens[hoveredTokenIndex],
index: hoveredTokenIndex,
...hoveredPosition,
}
: null,
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
return "bg-red-500/20 text-red-200/90"; // Draws attention
}
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return "border-transparent"; // No border for expected
if (probability > 0.5) return "border-gray-500/20";
if (probability > 0.2) return "border-amber-500/30";
return "border-red-500/40";
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(
event: MouseEvent,
token: TokenData,
index: number,
) {
clearHideTimeout();
const rects = (event.target as HTMLElement).getClientRects();
let rect = rects[0];
for (let j = 0; j < rects.length; j++) {
if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
rect = rects[j];
break;
}
}
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10,
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 200;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + "%";
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return "text-gray-300";
if (probability > 0.5) return "text-gray-400";
if (probability > 0.2) return "text-amber-400";
return "text-red-400";
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
tokenData.probability,
)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}>{tokenData.token}</span
>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50 pb-2"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div
class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
>
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1"
>"{hoveredToken.token.token}"</span
>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
>{formatProbability(hoveredToken.token.probability)}</span
>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono"
>{formatLogprob(hoveredToken.token.logprob)}</span
>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24"
>"{alt.token}"</span
>
<span class="text-gray-400 ml-2"
>{formatProbability(altProb)}</span
>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
/>
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>
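
The heatmap above styles each token by probability, and the tooltip recovers probabilities for the alternatives with `Math.exp(alt.logprob)`. The sketch below isolates that mapping. The thresholds (0.8, 0.5, 0.2) are taken from `getConfidenceClass`; the bucket names and the helper names `logprobToProbability` and `confidenceBucket` are hypothetical and stand in for the Tailwind class strings.

```typescript
// Hypothetical bucket names; the component returns CSS class strings instead.
type ConfidenceBucket = "expected" | "hint" | "uncertain" | "unlikely";

// A log-probability is the natural log of a probability, so exp() inverts it.
function logprobToProbability(logprob: number): number {
  return Math.exp(logprob);
}

// Same thresholds as getConfidenceClass: strictly greater-than comparisons,
// so a probability of exactly 0.8 falls into the next bucket down.
function confidenceBucket(probability: number): ConfidenceBucket {
  if (probability > 0.8) return "expected";
  if (probability > 0.5) return "hint";
  if (probability > 0.2) return "uncertain";
  return "unlikely";
}
```

For example, a token with logprob `0` has probability `1` and lands in the "expected" bucket, while a logprob of `-5` corresponds to a probability below 1% and lands in "unlikely".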

@@ -6,3 +6,9 @@ export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as ImageParamsPanel } from "./ImageParamsPanel.svelte";
export { default as FamilyLogos } from "./FamilyLogos.svelte";
export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

@@ -242,6 +242,19 @@ export interface MessageAttachment {
mimeType?: string;
}
export interface TopLogprob {
token: string;
logprob: number;
bytes: number[] | null;
}
export interface TokenData {
token: string;
logprob: number;
probability: number;
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -253,6 +266,7 @@ export interface Message {
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
tokens?: TokenData[];
}
export interface Conversation {
@@ -540,7 +554,18 @@ class AppStore {
*/
private saveConversationsToStorage() {
try {
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
// Strip tokens from messages before saving to avoid bloating localStorage
const stripped = this.conversations.map((conv) => ({
...conv,
messages: conv.messages.map((msg) => {
if (msg.tokens) {
const { tokens: _, ...rest } = msg;
return rest;
}
return msg;
}),
}));
localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
} catch (error) {
console.error("Failed to save conversations:", error);
}
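
The token-stripping step in `saveConversationsToStorage` can be sketched as a pure function, which makes it easy to check that the in-memory conversations keep their tokens while the serialized copy does not. `StoredMessage`, `StoredConversation`, and `stripTokens` are hypothetical names with reduced shapes, not the store's real types.

```typescript
// Hypothetical, reduced shapes for illustration only.
interface StoredMessage {
  id: string;
  content: string;
  tokens?: unknown[];
}

interface StoredConversation {
  id: string;
  messages: StoredMessage[];
}

// Returns a copy of the conversations with `tokens` removed from every
// message; the input objects are left untouched.
function stripTokens(conversations: StoredConversation[]): StoredConversation[] {
  return conversations.map((conv) => ({
    ...conv,
    messages: conv.messages.map((msg) => {
      // Object rest drops `tokens` and keeps every other field.
      const { tokens: _tokens, ...rest } = msg;
      return rest;
    }),
  }));
}
```

Because the copy is shallow and built with spread and rest, per-token logprob data stays available to the UI for the current session and is simply absent after a reload from storage.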
@@ -1445,6 +1470,213 @@ class AppStore {
}
}
/**
* Regenerate response from a specific token index.
* Truncates the assistant message at the given token and re-generates from there.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
if (this.isLoading) return;
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
const msgIndex = this.messages.findIndex((m) => m.id === messageId);
if (msgIndex === -1) return;
const msg = this.messages[msgIndex];
if (
msg.role !== "assistant" ||
!msg.tokens ||
tokenIndex >= msg.tokens.length
)
return;
// Keep tokens up to (not including) the specified index
const tokensToKeep = msg.tokens.slice(0, tokenIndex);
const prefixText = tokensToKeep.map((t) => t.token).join("");
// Remove all messages after this assistant message
this.messages = this.messages.slice(0, msgIndex + 1);
// Update the message to show the prefix
this.messages[msgIndex].content = prefixText;
this.messages[msgIndex].tokens = tokensToKeep;
this.updateActiveConversation();
// Set up for continuation - modify the existing message in place
this.isLoading = true;
this.currentResponse = prefixText;
this.ttftMs = null;
this.tps = null;
this.totalTokens = tokensToKeep.length;
try {
// Build messages for API - include the partial assistant message
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
const apiMessages = [
systemPrompt,
...this.messages.map((m) => {
let msgContent = m.content;
if (m.attachments) {
for (const attachment of m.attachments) {
if (attachment.type === "text" && attachment.content) {
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
}
}
}
return { role: m.role, content: msgContent };
}),
];
const modelToUse = this.getModelForRequest();
if (!modelToUse) {
throw new Error("No model available");
}
const requestStartTime = performance.now();
let firstTokenTime: number | null = null;
let tokenCount = tokensToKeep.length;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
let fullContent = prefixText;
const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
}
tokenCount += 1;
this.totalTokens = tokenCount;
if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
const elapsed = performance.now() - firstTokenTime;
// Skip the update while no time has elapsed to avoid dividing by zero
if (elapsed > 0) {
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
}
// Update existing message in place
this.updateConversationMessage(
targetConversationId,
messageId,
(m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
},
);
// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error regenerating from token:", error);
if (this.conversationExists(targetConversationId)) {
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} finally {
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
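The prefix handling at the top of `regenerateFromToken` can be looked at in isolation as a pure function. The sketch below is illustrative only: `truncateAtToken`, `TruncationResult` and the reduced `TokenData` shape are hypothetical names that do not appear in this diff, and rejecting negative or non-integer indices is an assumption about the intended behaviour.

```typescript
// Hypothetical, reduced token shape for illustration; the TokenData used by
// the store also carries logprob, probability and topLogprobs.
interface TokenData {
  token: string;
}

interface TruncationResult {
  tokensToKeep: TokenData[];
  prefixText: string;
}

// Keep tokens strictly before `tokenIndex` and rebuild the visible prefix.
// Returns null when the index is out of range (assumed guard, not from the diff).
function truncateAtToken(
  tokens: TokenData[],
  tokenIndex: number,
): TruncationResult | null {
  if (!Number.isInteger(tokenIndex)) return null;
  if (tokenIndex < 0 || tokenIndex >= tokens.length) return null;
  const tokensToKeep = tokens.slice(0, tokenIndex);
  const prefixText = tokensToKeep.map((t) => t.token).join("");
  return { tokensToKeep, prefixText };
}

const sample: TokenData[] = [
  { token: "The" },
  { token: " quick" },
  { token: " fox" },
];
const result = truncateAtToken(sample, 2);
console.log(result?.prefixText); // "The quick"
```

Keeping this step free of store state would make the boundary cases (index 0, last index, out of range) easy to unit test.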
/**
* Helper method to regenerate a chat completion response
*/
@@ -1513,6 +1745,8 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1527,16 +1761,49 @@ class AppStore {
 }
 let streamedContent = "";
+const collectedTokens: TokenData[] = [];
 interface ChatCompletionChunk {
-choices?: Array<{ delta?: { content?: string } }>;
+choices?: Array<{
+delta?: { content?: string };
+logprobs?: {
+content?: Array<{
+token: string;
+logprob: number;
+top_logprobs?: Array<{
+token: string;
+logprob: number;
+bytes: number[] | null;
+}>;
+}>;
+};
+}>;
 }
 await this.parseSSEStream<ChatCompletionChunk>(
 reader,
 targetConversationId,
 (parsed) => {
-const delta = parsed.choices?.[0]?.delta?.content;
+const choice = parsed.choices?.[0];
+const delta = choice?.delta?.content;
+// Collect logprobs data
+const logprobsContent = choice?.logprobs?.content;
+if (logprobsContent) {
+for (const item of logprobsContent) {
+collectedTokens.push({
+token: item.token,
+logprob: item.logprob,
+probability: Math.exp(item.logprob),
+topLogprobs: (item.top_logprobs || []).map((t) => ({
+token: t.token,
+logprob: t.logprob,
+bytes: t.bytes,
+})),
+});
+}
+}
 if (delta) {
 streamedContent += delta;
 const { displayContent, thinkingContent } =
@@ -1554,6 +1821,7 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1572,6 +1840,7 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1914,6 +2183,8 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1930,14 +2201,48 @@ class AppStore {
 let streamedContent = "";
 interface ChatCompletionChunk {
-choices?: Array<{ delta?: { content?: string } }>;
+choices?: Array<{
+delta?: { content?: string };
+logprobs?: {
+content?: Array<{
+token: string;
+logprob: number;
+top_logprobs?: Array<{
+token: string;
+logprob: number;
+bytes: number[] | null;
+}>;
+}>;
+};
+}>;
 }
+const collectedTokens: TokenData[] = [];
 await this.parseSSEStream<ChatCompletionChunk>(
 reader,
 targetConversationId,
 (parsed) => {
-const tokenContent = parsed.choices?.[0]?.delta?.content;
+const choice = parsed.choices?.[0];
+const tokenContent = choice?.delta?.content;
+// Collect logprobs data
+const logprobsContent = choice?.logprobs?.content;
+if (logprobsContent) {
+for (const item of logprobsContent) {
+collectedTokens.push({
+token: item.token,
+logprob: item.logprob,
+probability: Math.exp(item.logprob),
+topLogprobs: (item.top_logprobs || []).map((t) => ({
+token: t.token,
+logprob: t.logprob,
+bytes: t.bytes,
+})),
+});
+}
+}
 if (tokenContent) {
 // Track first token for TTFT
 if (firstTokenTime === null) {
@@ -1973,6 +2278,7 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1997,6 +2303,7 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
// Store performance metrics on the message
if (this.ttftMs !== null) {
msg.ttftMs = this.ttftMs;
@@ -2693,6 +3000,8 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
// Conversation actions
export const conversations = () => appStore.conversations;
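The logprobs-to-token mapping is repeated verbatim in three stream handlers in this file. A minimal sketch of a shared helper follows; `toTokenData`, `LogprobItem` and `TopLogprob` are hypothetical names, and the `TokenData` shape is inferred from the `collectedTokens.push({...})` calls rather than taken from its actual declaration, which is not part of this diff.

```typescript
// Shapes mirrored from the ChatCompletionChunk interface in the diff.
interface TopLogprob {
  token: string;
  logprob: number;
  bytes: number[] | null;
}

interface LogprobItem {
  token: string;
  logprob: number;
  top_logprobs?: TopLogprob[];
}

// Assumed shape, inferred from the objects pushed onto collectedTokens.
interface TokenData {
  token: string;
  logprob: number;
  probability: number;
  topLogprobs: TopLogprob[];
}

// Hypothetical helper: one place for the mapping used by each stream handler.
function toTokenData(item: LogprobItem): TokenData {
  return {
    token: item.token,
    logprob: item.logprob,
    // The diff treats logprob as a natural log, so exp() recovers p.
    probability: Math.exp(item.logprob),
    topLogprobs: (item.top_logprobs ?? []).map((t) => ({
      token: t.token,
      logprob: t.logprob,
      bytes: t.bytes,
    })),
  };
}

const converted = toTokenData({ token: "Hi", logprob: Math.log(0.5) });
console.log(converted.probability.toFixed(2)); // "0.50"
```

With such a helper each handler body could shrink to `collectedTokens.push(...logprobsContent.map(toTokenData))`, though whether that refactor is wanted here is a judgment call for the authors.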


@@ -0,0 +1,97 @@
/**
* FavoritesStore - Manages favorite models with localStorage persistence
*/
import { browser } from "$app/environment";
const FAVORITES_KEY = "exo-favorite-models";
class FavoritesStore {
favorites = $state<Set<string>>(new Set());
constructor() {
if (browser) {
this.loadFromStorage();
}
}
private loadFromStorage() {
try {
const stored = localStorage.getItem(FAVORITES_KEY);
if (stored) {
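// Validate the stored value before trusting it: it may be stale or hand-edited
const parsed: unknown = JSON.parse(stored);
if (Array.isArray(parsed)) {
this.favorites = new Set(
parsed.filter((id): id is string => typeof id === "string"),
);
}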
}
} catch (error) {
console.error("Failed to load favorites:", error);
}
}
private saveToStorage() {
try {
const array = Array.from(this.favorites);
localStorage.setItem(FAVORITES_KEY, JSON.stringify(array));
} catch (error) {
console.error("Failed to save favorites:", error);
}
}
add(baseModelId: string) {
const next = new Set(this.favorites);
next.add(baseModelId);
this.favorites = next;
this.saveToStorage();
}
remove(baseModelId: string) {
const next = new Set(this.favorites);
next.delete(baseModelId);
this.favorites = next;
this.saveToStorage();
}
toggle(baseModelId: string) {
if (this.favorites.has(baseModelId)) {
this.remove(baseModelId);
} else {
this.add(baseModelId);
}
}
isFavorite(baseModelId: string): boolean {
return this.favorites.has(baseModelId);
}
getAll(): string[] {
return Array.from(this.favorites);
}
getSet(): Set<string> {
return new Set(this.favorites);
}
hasAny(): boolean {
return this.favorites.size > 0;
}
clearAll() {
this.favorites = new Set();
this.saveToStorage();
}
}
export const favoritesStore = new FavoritesStore();
export const favorites = () => favoritesStore.favorites;
export const hasFavorites = () => favoritesStore.hasAny();
export const isFavorite = (baseModelId: string) =>
favoritesStore.isFavorite(baseModelId);
export const toggleFavorite = (baseModelId: string) =>
favoritesStore.toggle(baseModelId);
export const addFavorite = (baseModelId: string) =>
favoritesStore.add(baseModelId);
export const removeFavorite = (baseModelId: string) =>
favoritesStore.remove(baseModelId);
export const getFavorites = () => favoritesStore.getAll();
export const getFavoritesSet = () => favoritesStore.getSet();
export const clearFavorites = () => favoritesStore.clearAll();
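The store replaces the `Set` on every change instead of mutating it, presumably so that the `$state` assignment is what triggers updates. The sketch below shows the same copy-then-assign pattern and the JSON round trip as plain functions. Everything in it is hypothetical scaffolding for illustration: `KeyValueStorage`, `loadFavorites`, `saveFavorites`, `toggleInSet` and the in-memory storage stand-in are not part of this diff.

```typescript
// Minimal storage interface so the sketch runs outside a browser;
// window.localStorage satisfies it structurally.
interface KeyValueStorage {
  getItem(key: string): string | null;
  setItem(key: string, value: string): void;
}

const FAVORITES_KEY = "exo-favorite-models";

// Parse the persisted JSON array, tolerating missing or malformed data.
function loadFavorites(storage: KeyValueStorage): Set<string> {
  try {
    const stored = storage.getItem(FAVORITES_KEY);
    if (!stored) return new Set<string>();
    const parsed: unknown = JSON.parse(stored);
    if (!Array.isArray(parsed)) return new Set<string>();
    return new Set<string>(
      parsed.filter((v): v is string => typeof v === "string"),
    );
  } catch {
    return new Set<string>();
  }
}

function saveFavorites(storage: KeyValueStorage, favs: Set<string>): void {
  storage.setItem(FAVORITES_KEY, JSON.stringify(Array.from(favs)));
}

// Return a new Set rather than mutating, mirroring add()/remove() in the store.
function toggleInSet(current: Set<string>, id: string): Set<string> {
  const next = new Set<string>(current);
  if (!next.delete(id)) next.add(id);
  return next;
}

// In-memory stand-in for localStorage (demo only).
const memory = new Map<string, string>();
const fakeStorage: KeyValueStorage = {
  getItem: (key) => memory.get(key) ?? null,
  setItem: (key, value) => {
    memory.set(key, value);
  },
};

let favs = loadFavorites(fakeStorage);
favs = toggleInSet(favs, "exolabs/FLUX.1-dev");
saveFavorites(fakeStorage, favs);
console.log(Array.from(loadFavorites(fakeStorage))); // ["exolabs/FLUX.1-dev"]
```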


@@ -0,0 +1,152 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
// Accept either key casing, but only return an actual string
const modelId = meta.model_id ?? meta.modelId;
return typeof modelId === "string" ? modelId : null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
* Entries that are not a single-key tagged object, or that carry no model ID, are skipped.
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
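/**
* Find shard metadata for a model from any download entry across all nodes.
* Returns the shard metadata of the first DownloadCompleted entry if one exists;
* otherwise that of the first matching entry with any other tag, or null if none match.
*/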
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others,
* or null if the node has no entry for the model.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}
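As a worked example of the tagged-union walk described in the file header, the sketch below runs the same unwrapping over a small hand-written payload. The sample data, the node IDs and the helper names `modelIdOf` and `nodesWithModel` are invented for illustration; only the entry shape and the `DownloadCompleted` / `DownloadOngoing` tags come from the diff.

```typescript
type Json = Record<string, unknown>;

// Same idea as unwrapTagged in the diff: a tagged entry is an object with
// exactly one key (the tag) whose value is the payload object.
function unwrapTagged(obj: Json): [string, Json] | null {
  const [tag, ...others] = Object.keys(obj);
  if (tag === undefined || others.length > 0) return null;
  const payload = obj[tag];
  if (!payload || typeof payload !== "object") return null;
  return [tag, payload as Json];
}

// Walk entry -> payload -> shard_metadata -> shard tag -> model_card -> model_id.
function modelIdOf(entry: Json): string | null {
  const download = unwrapTagged(entry);
  if (!download) return null;
  const shardMetadata = download[1]["shard_metadata"];
  if (!shardMetadata || typeof shardMetadata !== "object") return null;
  const shard = unwrapTagged(shardMetadata as Json);
  if (!shard) return null;
  const card = shard[1]["model_card"];
  if (!card || typeof card !== "object") return null;
  const modelId = (card as Json)["model_id"];
  return typeof modelId === "string" ? modelId : null;
}

// Hypothetical helper for building one entry of the sample payload.
function makeEntry(tag: string, modelId: string): Json {
  return {
    [tag]: {
      shard_metadata: {
        PipelineShardMetadata: { model_card: { model_id: modelId } },
      },
    },
  };
}

const sampleDownloads: Record<string, Json[]> = {
  "node-a": [makeEntry("DownloadCompleted", "exolabs/FLUX.1-dev")],
  "node-b": [makeEntry("DownloadOngoing", "exolabs/FLUX.1-dev")],
};

// Nodes holding a completed download of the model.
function nodesWithModel(
  data: Record<string, Json[]>,
  modelId: string,
): string[] {
  return Object.keys(data).filter((nodeId) =>
    (data[nodeId] ?? []).some((entry) => {
      const tagged = unwrapTagged(entry);
      return (
        tagged !== null &&
        tagged[0] === "DownloadCompleted" &&
        modelIdOf(entry) === modelId
      );
    }),
  );
}

console.log(nodesWithModel(sampleDownloads, "exolabs/FLUX.1-dev")); // ["node-a"]
```

Only `node-a` is reported because the entry on `node-b` is still tagged `DownloadOngoing`.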


@@ -5,7 +5,13 @@
ChatMessages,
ChatSidebar,
ModelCard,
ModelPickerModal,
} from "$lib/components";
import {
favorites,
toggleFavorite,
getFavoritesSet,
} from "$lib/stores/favorites.svelte";
import {
hasStartedChat,
isTopologyMinimized,
@@ -100,6 +106,11 @@
storage_size_megabytes?: number;
tasks?: string[];
hugging_face_id?: string;
is_custom?: boolean;
family?: string;
quantization?: string;
base_model?: string;
capabilities?: string[];
}>
>([]);
@@ -211,9 +222,11 @@
 let launchingModelId = $state<string | null>(null);
 let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
-// Custom dropdown state
-let isModelDropdownOpen = $state(false);
-let modelDropdownSearch = $state("");
+// Model picker modal state
+let isModelPickerOpen = $state(false);
+// Favorites state (reactive)
+const favoritesSet = $derived(getFavoritesSet());
 // Slider dragging state
 let isDraggingSlider = $state(false);
@@ -530,6 +543,47 @@
}
}
async function addModelFromPicker(modelId: string) {
const response = await fetch("/models/add", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ model_id: modelId }),
});
if (!response.ok) {
let message = `Failed to add model (${response.status}: ${response.statusText})`;
try {
const err = await response.json();
if (err.detail) message = err.detail;
} catch {
// use default message
}
throw new Error(message);
}
await fetchModels();
}
async function deleteCustomModel(modelId: string) {
try {
const response = await fetch(
`/models/custom/${encodeURIComponent(modelId)}`,
{ method: "DELETE" },
);
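if (response.ok) {
await fetchModels();
} else {
console.error(
`Failed to delete custom model (${response.status}: ${response.statusText})`,
);
}
} catch (error) {
console.error("Failed to delete custom model:", error);
}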
}
function handleModelPickerSelect(modelId: string) {
selectPreviewModel(modelId);
saveLaunchDefaults();
isModelPickerOpen = false;
}
async function launchInstance(
modelId: string,
specificPreview?: PlacementPreview | null,
@@ -2360,14 +2414,12 @@
>
</div>
<!-- Model Dropdown (Custom) -->
<div class="flex-shrink-0 mb-3 relative">
<!-- Model Picker Button -->
<div class="flex-shrink-0 mb-3">
<button
type="button"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
? 'border-exo-yellow/70'
: ''}"
onclick={() => (isModelPickerOpen = true)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 relative"
>
{#if selectedModelId}
{@const foundModel = models.find(
@@ -2375,54 +2427,12 @@
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
<span
class="flex items-center gap-2 text-exo-light-gray truncate"
>
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2439,142 +2449,24 @@
{:else}
<span class="text-white/50"> SELECT MODEL </span>
{/if}
</button>
<div
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
{#if isModelDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-40 cursor-default"
onclick={() => (isModelDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel -->
<div
class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-50 max-h-64 overflow-y-auto"
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none"
>
<!-- Search within dropdown -->
<div
class="sticky top-0 bg-exo-dark-gray border-b border-exo-medium-gray/30 p-2"
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<input
type="text"
placeholder="Search models..."
bind:value={modelDropdownSearch}
class="w-full bg-exo-dark-gray/60 border border-exo-medium-gray/30 rounded px-2 py-1.5 text-xs font-mono text-white/80 placeholder:text-white/40 focus:outline-none focus:border-exo-yellow/50"
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</div>
<!-- Options -->
<div class="py-1">
{#each sortedModels().filter((m) => !modelDropdownSearch || (m.name || m.id)
.toLowerCase()
.includes(modelDropdownSearch.toLowerCase())) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
}
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
: 'text-white/30 cursor-default'}"
>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image generation model"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span
class="flex-shrink-0 text-xs {modelCanFit
? 'text-white/50'
: 'text-red-400/60'}"
>
{sizeGB >= 1
? sizeGB.toFixed(0)
: sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">
No models found
</div>
{/each}
</div>
</svg>
</div>
{/if}
</button>
</div>
<!-- Configuration Options -->
@@ -3354,3 +3246,24 @@
{/if}
</main>
</div>
<ModelPickerModal
isOpen={isModelPickerOpen}
{models}
{selectedModelId}
favorites={favoritesSet}
existingModelIds={new Set(models.map((m) => m.id))}
canModelFit={(modelId) => {
const model = models.find((m) => m.id === modelId);
return model ? hasEnoughMemory(model) : false;
}}
onSelect={handleModelPickerSelect}
onClose={() => (isModelPickerOpen = false)}
onToggleFavorite={toggleFavorite}
onAddModel={addModelFromPicker}
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>
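The two memory props divide a byte count by 1024 * 1024 * 1024, so despite the `GB` suffix the values passed to the modal are binary gigabytes (GiB). A minimal sketch of that conversion follows; `bytesToGiB` and `BYTES_PER_GIB` are hypothetical names, and the input is the `storage_size.in_bytes` value of the `exolabs/FLUX.1-dev` card added elsewhere in this diff.

```typescript
// 2^30 bytes per GiB, the same divisor used for totalMemoryGB and usedMemoryGB.
const BYTES_PER_GIB = 1024 * 1024 * 1024;

function bytesToGiB(bytes: number): number {
  return bytes / BYTES_PER_GIB;
}

// storage_size.in_bytes of exolabs/FLUX.1-dev
const fluxDevBytes = 33327437952;
console.log(bytesToGiB(fluxDevBytes).toFixed(2)); // "31.04"
```

The same byte count divided by 10^9 would read as roughly 33.33, so labels should be consistent about which unit is shown.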


@@ -118,9 +118,10 @@
 {
 metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
 mlx = pkgs.callPackage ./nix/mlx.nix {
-metal-toolchain = self'.packages.metal-toolchain;
+inherit (self'.packages) metal-toolchain;
 inherit uvLockMlxVersion;
 };
+default = self'.packages.exo;
 }
 );


@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]


@@ -19,7 +19,7 @@ dependencies = [
 "anyio==4.11.0",
 "mlx==0.30.4; sys_platform == 'darwin'",
 "mlx[cpu]==0.30.4; sys_platform == 'linux'",
-"mlx-lm",
+"mlx-lm==0.30.6",
 "tiktoken>=0.12.0", # required for kimi k2 tokenizer
 "hypercorn>=0.18.0",
 "openai-harmony>=0.0.8",
@@ -63,7 +63,6 @@ members = [
 [tool.uv.sources]
 exo_pyo3_bindings = { workspace = true }
-mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
 # Uncomment to use local mlx/mlx-lm development versions:
 # mlx = { path = "/Users/Shared/mlx", editable=true }
 # mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }


@@ -69,7 +69,8 @@
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0
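In this card the top-level `storage_size.in_bytes` equals the sum of the per-component sizes: 9524621312 + 5950704160 = 15475325472, with `text_encoder` and `vae` recorded as 0. The sketch below checks that arithmetic. The `CardSizes` and `ComponentSize` interfaces are hypothetical camelCase renderings of the TOML keys, not types from the codebase, and whether the loader enforces this invariant is not shown in the diff.

```typescript
// Hypothetical in-memory shape for the size fields of a model card.
interface ComponentSize {
  componentName: string;
  inBytes: number;
}

interface CardSizes {
  modelId: string;
  inBytes: number;
  components: ComponentSize[];
}

// Values transcribed from the FLUX.1-Krea-dev-4bit card.
const kreaDev4bit: CardSizes = {
  modelId: "exolabs/FLUX.1-Krea-dev-4bit",
  inBytes: 15475325472,
  components: [
    { componentName: "text_encoder", inBytes: 0 },
    { componentName: "text_encoder_2", inBytes: 9524621312 },
    { componentName: "transformer", inBytes: 5950704160 },
    { componentName: "vae", inBytes: 0 },
  ],
};

// Sum the per-component sizes; all values are well below 2^53, so the
// floating-point addition is exact.
function componentTotal(card: CardSizes): number {
  return card.components.reduce((sum, c) => sum + c.inBytes, 0);
}

console.log(componentTotal(kreaDev4bit) === kreaDev4bit.inBytes); // true
```

A check like this could run in a test over every card file to catch transcription slips when new quantizations are added.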


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0


@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,36 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,36 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,36 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,36 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,36 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,36 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405874409472

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 765577920512

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 122406567936

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 229780750336

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 198556925568

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 286737579648

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 396963397248

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 19327352832

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 22548578304

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 26843545600

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = "4bit"
base_model = "Kimi K2"
capabilities = ["text"]
[storage_size]
in_bytes = 620622774272

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 706522120192

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 662498705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 1B"
capabilities = ["text"]
[storage_size]
in_bytes = 729808896

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 1863319552

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 3501195264

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 76799803392

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 4637851648

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 8954839040

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 16882073600

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 100086644736

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 342884352

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 698351616

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 141733920768

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 268435456000

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 17612931072

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 33279705088

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 289910292480

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 579820584960

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 45644286500

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 57657697020

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 68899327465

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 89357758772

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 157548627945

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 46976204800

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 47080074240

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 120B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 70652212224

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 20B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 12025908224

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "fp16"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 144383672320
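
Review note: the text model cards carry `n_layers` and `storage_size.in_bytes`, which is enough for a back-of-the-envelope estimate of how a model might be split by layer. The sketch below is only an illustration: it assumes weights are spread evenly across layers (ignoring embeddings and uneven layers) and that each node gets a contiguous layer range. Both helper names are hypothetical, and exo's real placement logic is not part of this diff.

```python
def split_layers(n_layers: int, n_nodes: int) -> list[range]:
    """Split layer indices into n_nodes contiguous ranges of near-equal size."""
    base, extra = divmod(n_layers, n_nodes)
    ranges: list[range] = []
    start = 0
    for i in range(n_nodes):
        # The first `extra` nodes take one additional layer each.
        size = base + (1 if i < extra else 0)
        ranges.append(range(start, start + size))
        start += size
    return ranges


def approx_bytes_per_node(in_bytes: int, n_layers: int, n_nodes: int) -> list[int]:
    """Rough per-node size, assuming every layer weighs in_bytes / n_layers."""
    return [in_bytes * len(r) // n_layers for r in split_layers(n_layers, n_nodes)]


# Values taken from the Llama-3.3-70B-Instruct-4bit card above.
parts = split_layers(80, 3)
sizes = approx_bytes_per_node(40652242944, 80, 3)
```

With 80 layers over 3 nodes the ranges have 27, 27, and 26 layers.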

View File

@@ -1,4 +1,5 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
@@ -52,18 +53,44 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
async def _check_internet_connection(self) -> None:
first_connection = True
while True:
await asyncio.sleep(10)
# Downloads set internet_connection to False when an HTTPS (port 443) request fails, so only re-probe while offline.
if self.shard_downloader.internet_connection:
continue
self._test_internet_connection()
if first_connection and self.shard_downloader.internet_connection:
first_connection = False
self._tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
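
Review note: the connectivity probe in this hunk reduces to "open a TCP connection to 1.1.1.1:443 with a short timeout and treat any `OSError` as offline". A standalone sketch of that idea follows. The function name `has_internet` and the injectable `connect` parameter are assumptions added here so the probe can be exercised without a network; they are not part of the diff.

```python
import socket
from typing import Any, Callable


def has_internet(
    host: str = "1.1.1.1",
    port: int = 443,
    timeout: float = 3.0,
    connect: Callable[..., Any] = socket.create_connection,
) -> bool:
    """Return True if a TCP connection to (host, port) can be opened within timeout."""
    try:
        # Open the connection and close it immediately; no data is sent.
        connect((host, port), timeout=timeout).close()
        return True
    except OSError:
        # Covers refused connections, unreachable networks, and timeouts.
        return False
```

Calling `has_internet()` with the defaults performs the same probe as the diff. Passing a fake `connect` keeps tests offline.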
@@ -241,7 +268,7 @@ class DownloadCoordinator:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
logger.debug(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
@@ -274,10 +301,10 @@ class DownloadCoordinator:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
logger.debug(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
await anyio.sleep(60)
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"

View File

@@ -49,6 +49,10 @@ class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
class HuggingFaceRateLimitError(Exception):
"""Raised when HuggingFace returns 429 (rate limited)."""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
@@ -154,49 +158,76 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
_fetched_file_lists_this_session: set[str] = set()
async def fetch_file_list_with_cache(
model_id: ModelId, revision: str = "main", recursive: bool = False
model_id: ModelId,
revision: str = "main",
recursive: bool = False,
skip_internet: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_key = f"{model_id.normalize()}--{revision}"
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
cache_file
):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
if skip_internet:
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
model_id,
revision,
recursive=recursive,
on_connection_lost=on_connection_lost,
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
# No cache available, propagate the error
raise
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
async def fetch_file_list_with_retry(
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId,
revision: str = "main",
path: str = "",
recursive: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
) -> list[FileListEntry]:
n_attempts = 30
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
await asyncio.sleep(2.0**attempt)
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
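
Review note: the new retry policy is three attempts with a `2.0**attempt` second delay, so a call that keeps failing sleeps 1 s and then 2 s before the final error propagates. A generic sketch of that loop follows. `retry_with_backoff`, `FatalError`, and the injectable `sleep` are hypothetical names for illustration; `FatalError` stands in for errors such as `HuggingFaceAuthenticationError` that are re-raised without retrying.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


class FatalError(Exception):
    """Hypothetical stand-in for errors that must not be retried (e.g. auth failures)."""


async def retry_with_backoff(
    op: Callable[[], Awaitable[T]],
    n_attempts: int = 3,
    on_failure: Callable[[], None] = lambda: None,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> T:
    for attempt in range(n_attempts):
        try:
            return await op()
        except FatalError:
            # Retrying cannot fix this class of error, so surface it immediately.
            raise
        except Exception:
            # Notify the caller on every failed attempt, as the diff does.
            on_failure()
            if attempt == n_attempts - 1:
                raise
            # Exponential backoff: 1 s after the first failure, 2 s after the second.
            await sleep(2.0**attempt)
    # Only reached when n_attempts < 1.
    raise RuntimeError("retry_with_backoff called with n_attempts < 1")
```

Injecting `sleep` keeps the delays observable in tests without actually waiting.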
@@ -216,7 +247,11 @@ async def _fetch_file_list(
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
if response.status == 200:
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
)
elif response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
@@ -249,7 +284,7 @@ def create_http_session(
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_read_timeout = 60
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
@@ -324,8 +359,9 @@ async def download_file_with_retry(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> Path:
n_attempts = 30
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
@@ -333,14 +369,19 @@ async def download_file_with_retry(
)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
break
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -542,7 +583,9 @@ async def download_shard(
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
max_parallel_downloads: int = 8,
skip_download: bool = False,
skip_internet: bool = False,
allow_patterns: list[str] | None = None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.debug(f"Downloading {shard.model_card.model_id=}")
@@ -562,7 +605,11 @@ async def download_shard(
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id, revision, recursive=True
shard.model_card.model_id,
revision,
recursive=True,
skip_internet=skip_internet,
on_connection_lost=on_connection_lost,
)
filtered_file_list = list(
filter_repo_objects(
@@ -672,6 +719,7 @@ async def download_shard(
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
)
if not skip_download:

View File

@@ -1,4 +1,5 @@
import asyncio
from asyncio import create_task
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -7,7 +8,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -49,6 +50,10 @@ class SingletonShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -85,6 +90,10 @@ class CachedShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -142,6 +151,8 @@ class ResumableShardDownloader(ShardDownloader):
self.on_progress_wrapper,
max_parallel_downloads=self.max_parallel_downloads,
allow_patterns=allow_patterns,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return target_dir
@@ -154,13 +165,24 @@ class ResumableShardDownloader(ShardDownloader):
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)
return await download_shard(
shard, self.on_progress_wrapper, skip_download=True
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
# Kick off download status coroutines concurrently
semaphore = asyncio.Semaphore(self.max_parallel_downloads)
async def download_with_semaphore(
model_card: ModelCard,
) -> tuple[Path, RepoDownloadProgress]:
async with semaphore:
return await _status_for_model(model_card.model_id)
tasks = [
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
create_task(download_with_semaphore(model_card))
for model_card in await get_model_cards()
]
for task in asyncio.as_completed(tasks):
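
Review note: the status check now wraps each per-model coroutine in a semaphore so that at most `max_parallel_downloads` run at once, while still consuming results with `asyncio.as_completed`. A minimal sketch of the same pattern follows. `run_bounded` and its parameters are hypothetical names, not functions from this repository.

```python
import asyncio
from typing import Awaitable, Callable, Iterable, TypeVar

A = TypeVar("A")
R = TypeVar("R")


async def run_bounded(
    items: Iterable[A],
    worker: Callable[[A], Awaitable[R]],
    limit: int,
) -> list[R]:
    """Run worker over items with at most `limit` calls in flight at once."""
    semaphore = asyncio.Semaphore(limit)

    async def guarded(item: A) -> R:
        # Each task waits here until one of the `limit` slots is free.
        async with semaphore:
            return await worker(item)

    # All tasks are created up front; the semaphore throttles the actual work.
    tasks = [asyncio.create_task(guarded(item)) for item in items]
    results: list[R] = []
    for finished in asyncio.as_completed(tasks):
        # Results arrive in completion order, not submission order.
        results.append(await finished)
    return results
```

Because results arrive in completion order, callers that need the original order should carry an index or key through the worker.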

View File

@@ -16,6 +16,11 @@ from exo.shared.types.worker.shards import (
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
internet_connection: bool = False
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
@abstractmethod
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False

View File

@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -136,7 +135,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -148,6 +146,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
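
Review note: this hunk registers the same shutdown callback for both SIGINT and SIGTERM, and the comment says a second call to `shutdown` should just exit. A small sketch of that two-stage behaviour follows. `ShutdownController`, the injected `cancel` and `force_exit` callables, and the exit code 1 are assumptions for illustration; the actual `Node.shutdown` body is not visible in this diff.

```python
import signal
import sys
from typing import Callable


class ShutdownController:
    """First request cancels gracefully; a second request forces an exit."""

    def __init__(
        self,
        cancel: Callable[[], None],
        force_exit: Callable[[int], None] = sys.exit,
    ) -> None:
        self._cancel = cancel
        self._force_exit = force_exit
        self._requested = False

    def shutdown(self) -> None:
        if self._requested:
            # Second call: graceful cancellation is taking too long, so exit now.
            self._force_exit(1)  # exit code 1 is an assumption
            return
        self._requested = True
        self._cancel()

    def install(self) -> None:
        # signal.signal may only be called from the main thread.
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, lambda _signum, _frame: self.shutdown())
```

In an application, `cancel` would typically be something like a task group's `cancel_scope.cancel`, and `install()` would be called once at startup.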

View File

@@ -14,6 +14,8 @@ from exo.shared.types.api import (
ErrorInfo,
ErrorResponse,
FinishReason,
Logprobs,
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
)
@@ -66,7 +68,9 @@ def chat_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages if input_messages else "",
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
@@ -79,6 +83,8 @@ def chat_request_to_text_generation(
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
logprobs=request.logprobs or False,
top_logprobs=request.top_logprobs,
)
@@ -86,6 +92,19 @@ def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -94,6 +113,7 @@ def chunk_to_response(
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
@@ -160,6 +180,7 @@ async def collect_chat_response(
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_content: list[LogprobsContentItem] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
@@ -174,6 +195,14 @@ async def collect_chat_response(
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
@@ -206,6 +235,9 @@ async def collect_chat_response(
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
),
logprobs=Logprobs(content=logprobs_content)
if logprobs_content
else None,
finish_reason=finish_reason,
)
],
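Review note: the logprobs change in this file does two related things. It wraps a single token's log probability for each streamed chunk, and it accumulates those items across chunks for the non-streaming response. A minimal sketch of that shape follows. The dataclasses are simplified stand-ins rather than the real exo types, and `to_item` and `collect_logprobs` are hypothetical helper names used only for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TokenChunk:
    # Simplified stand-in: the real chunk type carries more fields.
    text: str
    logprob: Optional[float] = None
    top_logprobs: Optional[list] = None


@dataclass
class LogprobsContentItem:
    token: str
    logprob: float
    top_logprobs: list = field(default_factory=list)


@dataclass
class Logprobs:
    content: list


def to_item(chunk):
    """Return a content item, or None when the chunk carries no logprob."""
    if chunk.logprob is None:
        return None
    return LogprobsContentItem(
        token=chunk.text,
        logprob=chunk.logprob,
        top_logprobs=chunk.top_logprobs or [],
    )


def collect_logprobs(chunks):
    """Accumulate items across chunks; None if nothing was collected."""
    items = [item for item in map(to_item, chunks) if item is not None]
    return Logprobs(content=items) if items else None


chunks = [TokenChunk("Hel", -0.1), TokenChunk("lo", -0.3), TokenChunk("!")]
collected = collect_logprobs(chunks)
```

In this sketch, chunks without a logprob are skipped, and the overall result is `None` rather than an empty container when no chunk carried one, mirroring the `if logprobs_content else None` expression in the diff.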


@@ -141,7 +141,9 @@ def claude_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages if input_messages else "",
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,


@@ -43,10 +43,10 @@ def _extract_content(content: str | list[ResponseContentPart]) -> str:
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: str | list[InputMessage]
input_value: list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = request.input
input_value = [InputMessage(role="user", content=request.input)]
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
@@ -95,7 +95,11 @@ def responses_request_to_text_generation(
}
)
input_value = input_messages if input_messages else ""
input_value = (
input_messages
if input_messages
else [InputMessage(role="user", content="")]
)
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(
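Review note: across the three request adapters in this diff, `input` changes from `str | list[InputMessage]` to always being a non-empty `list[InputMessage]`. The sketch below shows that normalization in isolation. `InputMessage` here is a simplified stand-in and `normalize_input` is a hypothetical helper name; the real code inlines the equivalent expressions.

```python
from dataclasses import dataclass


@dataclass
class InputMessage:
    # Simplified stand-in for the real InputMessage type.
    role: str
    content: str


def normalize_input(raw):
    """Coerce a string or message list into a non-empty list of messages."""
    if isinstance(raw, str):
        # A bare string becomes a single user message.
        return [InputMessage(role="user", content=raw)]
    messages = list(raw)
    # An empty list falls back to a single empty user message.
    return messages if messages else [InputMessage(role="user", content="")]


from_string = normalize_input("Hello, how are you?")
from_empty = normalize_input([])
```

Downstream code can then assume one type for `input`, which is presumably the motivation for also updating the test fixtures later in this diff.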


@@ -1,6 +1,7 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from datetime import datetime, timezone
@@ -40,6 +41,7 @@ from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
@@ -47,12 +49,15 @@ from exo.shared.constants import (
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
delete_custom_card,
get_model_cards,
is_custom_card,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
AddCustomModelParams,
AdvancedImageParams,
BenchChatCompletionRequest,
BenchChatCompletionResponse,
@@ -70,6 +75,7 @@ from exo.shared.types.api import (
ErrorResponse,
FinishReason,
GenerationStats,
HuggingFaceSearchResult,
ImageData,
ImageEditsTaskParams,
ImageGenerationResponse,
@@ -138,7 +144,6 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@@ -146,16 +151,13 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
"""Ensure advanced params has a seed set for distributed consistency."""
if params is None:
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
if params.seed is None:
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
return params
class API:
@@ -204,7 +206,7 @@ class API:
self.app.mount(
"/",
StaticFiles(
directory=find_dashboard(),
directory=DASHBOARD_DIR,
html=True,
),
name="dashboard",
@@ -269,6 +271,9 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
self.app.delete("/models/custom/{model_id:path}")(self.delete_custom_model)
self.app.get("/models/search")(self.search_models)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
@@ -381,10 +386,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
model_card = await ModelCard.load(model_id)
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
for sharding in (Sharding.Pipeline, Sharding.Tensor):
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
@@ -399,96 +401,93 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
instance=None,
error=str(exc),
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
return PlacementPreviewResponse(previews=previews)
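Review note: the `memory_delta_by_node` computation above splits the model's storage size across the placement nodes so that the per-node values sum exactly to the total, handing the leftover bytes to the first nodes in sorted order. A standalone sketch of that arithmetic, where `split_bytes_evenly` is a hypothetical name and node IDs are plain strings:

```python
def split_bytes_evenly(total_bytes, node_ids):
    """Split total_bytes across node_ids; earlier nodes absorb the remainder."""
    if not node_ids:
        # The diff guards this case with `if placement_node_ids:`.
        return {}
    ordered = sorted(node_ids, key=str)
    per_node, remainder = divmod(total_bytes, len(ordered))
    return {
        str(node_id): per_node + (1 if index < remainder else 0)
        for index, node_id in enumerate(ordered)
    }


split = split_bytes_evenly(10, ["b", "a", "c"])
```

Sorting by `str` before distributing the remainder keeps the result deterministic regardless of the order in which node IDs were collected.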
@@ -628,6 +627,11 @@ class API:
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await collect_chat_response(
@@ -652,23 +656,21 @@ class API:
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved
instance.shard_assignments.model_id == model_id
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved)
await self._trigger_notify_user_to_download_model(model_id)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {resolved}",
detail=f"No instance found for model {model_id}",
)
return resolved
return model_id
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
@@ -722,6 +724,9 @@ class API:
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
task_params=payload,
@@ -970,6 +975,9 @@ class API:
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
task_params=payload,
@@ -1001,6 +1009,7 @@ class API:
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
advanced_params = _ensure_seed(advanced_params)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1179,6 +1188,11 @@ class API:
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await collect_claude_response(
@@ -1206,6 +1220,11 @@ class API:
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await collect_responses_response(
@@ -1236,35 +1255,105 @@ class API:
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
family=card.family,
quantization=card.quantization,
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in MODEL_CARDS.values()
for card in await get_model_cards()
]
)
async def add_custom_model(self, payload: AddCustomModelParams) -> ModelListModel:
"""Fetch a model from HuggingFace and save as a custom model card."""
try:
card = await ModelCard.fetch_from_hf(payload.model_id)
except Exception as exc:
raise HTTPException(
status_code=400, detail=f"Failed to fetch model: {exc}"
) from exc
return ModelListModel(
id=card.model_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=True,
)
async def delete_custom_model(self, model_id: ModelId) -> JSONResponse:
"""Delete a user-added custom model card."""
deleted = await delete_custom_card(model_id)
if not deleted:
raise HTTPException(status_code=404, detail="Custom model card not found")
return JSONResponse(
{"message": "Model card deleted", "model_id": str(model_id)}
)
async def search_models(
self, query: str = "", limit: int = 20
) -> list[HuggingFaceSearchResult]:
"""Search HuggingFace Hub for mlx-community models."""
from huggingface_hub import list_models
results = list_models(
search=query or None,
author="mlx-community",
sort="downloads",
limit=limit,
)
return [
HuggingFaceSearchResult(
id=m.id,
author=m.author or "",
downloads=m.downloads or 0,
likes=m.likes or 0,
last_modified=str(m.last_modified or ""),
tags=list(m.tags or []),
)
for m in results
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
cfg.bind = [f"0.0.0.0:{self.port}"]
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
with anyio.CancelScope(shield=True):
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=lambda: anyio.sleep_forever(),
shutdown_trigger=ev.wait,
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
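Review note: `_ensure_seed`, added earlier in this file, fills in a random seed before the image command is sent, so that (per its docstring) the distributed nodes agree on one value. The sketch below reproduces the same three-way branch with a frozen dataclass in place of the pydantic model. `AdvancedParams` and its `steps` field are hypothetical, and `dataclasses.replace` stands in for `model_copy(update=...)`.

```python
import random
from dataclasses import dataclass, replace
from typing import Optional

MAX_SEED = 2**32 - 1


@dataclass(frozen=True)
class AdvancedParams:
    # Hypothetical stand-in for AdvancedImageParams.
    seed: Optional[int] = None
    steps: int = 20


def ensure_seed(params, rng=random):
    """Return params with a concrete seed, leaving other fields untouched."""
    if params is None:
        return AdvancedParams(seed=rng.randint(0, MAX_SEED))
    if params.seed is None:
        # Copy rather than mutate, as the diff does with model_copy.
        return replace(params, seed=rng.randint(0, MAX_SEED))
    return params


filled = ensure_seed(None)
kept = ensure_seed(AdvancedParams(seed=7))
```

A caller-supplied seed is never overwritten, so reproducible requests stay reproducible.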


@@ -96,16 +96,18 @@ class Master:
async def run(self):
logger.info("Starting Master")
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
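Review note: the recurring change in this commit is moving channel cleanup into a `finally:` block so it still runs when the task group is cancelled. The sketch below illustrates why that matters. It uses the standard library's `asyncio` instead of anyio, purely so it runs without dependencies, and `Channel` is a hypothetical stand-in for the sender and receiver objects.

```python
import asyncio


class Channel:
    """Hypothetical stand-in for a closable channel endpoint."""

    def __init__(self, name):
        self.name = name
        self.closed = False

    def close(self):
        self.closed = True


async def run(channels):
    try:
        # Stand-in for the task group body: runs until cancelled.
        await asyncio.sleep(3600)
    finally:
        # Reached on normal exit and on cancellation alike.
        for channel in channels:
            channel.close()


async def main():
    channels = [Channel("events"), Channel("commands")]
    task = asyncio.create_task(run(channels))
    await asyncio.sleep(0)  # yield once so run() enters its try block
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return [channel.closed for channel in channels]


result = asyncio.run(main())
```

Had the `close()` calls been placed after the `await` without `finally`, the cancellation would have skipped them, which matches the "cleanup logic wasn't running" symptom described in the commit message.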


@@ -10,6 +10,7 @@ from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
Sharding,
ShardMetadata,
@@ -74,40 +75,43 @@ def allocate_layers_proportionally(
return result
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
def _validate_cycle(cycle: Cycle) -> None:
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
def _compute_total_memory(
node_ids: list[NodeId],
node_memory: Mapping[NodeId, MemoryUsage],
) -> Memory:
total_memory = sum(
(node_memory[node_id].ram_available for node_id in node_ids),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
if total_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
return total_memory
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
def _allocate_and_validate_layers(
node_ids: list[NodeId],
node_memory: Mapping[NodeId, MemoryUsage],
total_memory: Memory,
model_card: ModelCard,
) -> list[int]:
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
total_storage_bytes = model_card.storage_size.in_bytes
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
@@ -116,32 +120,125 @@ def get_shard_assignments_for_pipeline_parallel(
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
return layer_allocations
shard = PipelineShardMetadata(
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for pipeline parallel execution."""
world_size = len(cycle)
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
if use_cfg_parallel:
return _get_shard_assignments_for_cfg_parallel(model_card, cycle, node_memory)
else:
return _get_shard_assignments_for_pure_pipeline(model_card, cycle, node_memory)
def _get_shard_assignments_for_cfg_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for CFG parallel execution.
CFG parallel runs two independent pipelines. Group 0 processes the positive
prompt, group 1 processes the negative prompt. The ring topology places
group 1's ranks in reverse order so both "last stages" are neighbors for
efficient CFG exchange.
"""
_validate_cycle(cycle)
world_size = len(cycle)
cfg_world_size = 2
pipeline_world_size = world_size // cfg_world_size
# Allocate layers for one pipeline group (both groups run the same layers)
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
pipeline_memory = _compute_total_memory(pipeline_node_ids, node_memory)
layer_allocations = _allocate_and_validate_layers(
pipeline_node_ids, node_memory, pipeline_memory, model_card
)
# Ring topology: group 0 ascending [0,1,2,...], group 1 descending [...,2,1,0]
# This places both last stages as neighbors for CFG exchange.
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
(1, r) for r in reversed(range(pipeline_world_size))
]
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for device_rank, node_id in enumerate(cycle.node_ids):
cfg_rank, pipeline_rank = position_to_cfg_pipeline[device_rank]
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
shard = CfgShardMetadata(
model_card=model_card,
device_rank=i,
device_rank=device_rank,
world_size=world_size,
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
start_layer=layers_before,
end_layer=layers_before + node_layers,
n_layers=model_card.n_layers,
cfg_rank=cfg_rank,
cfg_world_size=cfg_world_size,
pipeline_rank=pipeline_rank,
pipeline_world_size=pipeline_world_size,
)
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
return ShardAssignments(
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
return shard_assignments
def _get_shard_assignments_for_pure_pipeline(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for pure pipeline execution."""
_validate_cycle(cycle)
total_memory = _compute_total_memory(cycle.node_ids, node_memory)
layer_allocations = _allocate_and_validate_layers(
cycle.node_ids, node_memory, total_memory, model_card
)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for pipeline_rank, node_id in enumerate(cycle.node_ids):
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=pipeline_rank,
world_size=len(cycle),
start_layer=layers_before,
end_layer=layers_before + node_layers,
n_layers=model_card.n_layers,
)
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
return ShardAssignments(
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def get_shard_assignments_for_tensor_parallel(
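Review note: the ring layout described in the `_get_shard_assignments_for_cfg_parallel` docstring is easier to check with concrete numbers. The sketch below extracts the position mapping and the layer-range arithmetic into two small functions. `cfg_ring_positions` and `layer_range` are hypothetical names; the expressions inside them are taken from the diff.

```python
def cfg_ring_positions(world_size):
    """Map each ring position to a (cfg_rank, pipeline_rank) pair."""
    if world_size < 2 or world_size % 2 != 0:
        raise ValueError("CFG parallel needs an even world size of at least 2")
    pipeline_world_size = world_size // 2
    # Group 0 ascends, group 1 descends, so the last stages sit side by side.
    group_0 = [(0, r) for r in range(pipeline_world_size)]
    group_1 = [(1, r) for r in reversed(range(pipeline_world_size))]
    return group_0 + group_1


def layer_range(layer_allocations, pipeline_rank):
    """Return the (start_layer, end_layer) pair for one pipeline stage."""
    start = sum(layer_allocations[:pipeline_rank])
    return start, start + layer_allocations[pipeline_rank]


positions = cfg_ring_positions(4)
```

With four nodes the mapping is `[(0, 0), (0, 1), (1, 1), (1, 0)]`: ring positions 1 and 2 both hold pipeline rank 1, the last stage of each group, and are adjacent. With two nodes each group has a single stage, so both nodes hold every layer.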


@@ -28,7 +28,7 @@ from exo.shared.types.profiling import (
)
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
InstanceMeta,
MlxRingInstance,
@@ -136,7 +136,9 @@ async def test_master():
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
input=[
InputMessage(role="user", content="Hello, how are you?")
],
),
)
),
@@ -189,7 +191,7 @@ async def test_master():
assert isinstance(events[2].event.task, TextGenerationTask)
assert events[2].event.task.task_params == TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
input=[InputMessage(role="user", content="Hello, how are you?")],
)
await master.shutdown()


@@ -5,6 +5,7 @@ from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
@@ -20,7 +21,11 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
Sharding,
)
def test_filter_cycles_by_memory():
@@ -487,3 +492,193 @@ def test_get_shard_assignments_insufficient_memory_raises():
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)
class TestCfgParallelPlacement:
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for i, node_id in enumerate(node_ids):
next_node = node_ids[(i + 1) % len(node_ids)]
conn = Connection(
source=node_id,
sink=next_node,
edge=create_socket_connection(i + 1),
)
topology.add_connection(conn)
return topology
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# CFG models should get CfgShardMetadata
for shard in shards:
assert isinstance(shard, CfgShardMetadata)
# Both nodes should have all layers (no pipeline split)
assert shard.start_layer == 0
assert shard.end_layer == 60
assert shard.cfg_world_size == 2
# Each node is the only stage in its pipeline group
assert shard.pipeline_world_size == 1
assert shard.pipeline_rank == 0
cfg_ranks = sorted(
s.cfg_rank for s in shards if isinstance(s, CfgShardMetadata)
)
assert cfg_ranks == [0, 1]
def test_four_nodes_cfg_model_uses_hybrid(self):
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
nodes = [NodeId() for _ in range(4)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 4]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 4
# CFG models should get CfgShardMetadata
for shard in shards:
assert isinstance(shard, CfgShardMetadata)
assert shard.cfg_world_size == 2
assert shard.pipeline_world_size == 2
assert shard.pipeline_rank in [0, 1]
# Check we have 2 nodes in each CFG group
cfg_0_shards = [
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 0
]
cfg_1_shards = [
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 1
]
assert len(cfg_0_shards) == 2
assert len(cfg_1_shards) == 2
# Both CFG groups should have the same layer assignments
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
"""Three nodes (odd) with CFG model should use sequential CFG (PipelineShardMetadata)."""
nodes = [NodeId() for _ in range(3)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 3]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 3
# Odd node count with CFG model falls back to PipelineShardMetadata (sequential CFG)
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
def test_two_nodes_non_cfg_model_uses_pipeline(self):
"""Two nodes with non-CFG model should use pure pipeline (PipelineShardMetadata)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("flux-test"),
n_layers=57,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=False, # Non-CFG model
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# Non-CFG models should get PipelineShardMetadata
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# Should have actual layer sharding (pipeline)
layer_ranges = sorted(
(s.start_layer, s.end_layer)
for s in shards
if isinstance(s, PipelineShardMetadata)
)
# First shard starts at 0, last shard ends at 57
assert layer_ranges[0][0] == 0
assert layer_ranges[-1][1] == 57


@@ -9,6 +9,7 @@ from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
@@ -146,18 +147,21 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
try:
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
with move_on_after(1, shield=True):
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -166,12 +170,12 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:


@@ -2,6 +2,8 @@ import os
import sys
from pathlib import Path
from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
@@ -31,6 +33,14 @@ EXO_MODELS_DIR = (
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"
@@ -48,6 +58,8 @@ LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
EXO_CUSTOM_MODEL_CARDS_DIR = EXO_DATA_HOME / "custom_model_cards"
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_TRACING_CACHE_DIR = EXO_CACHE_HOME / "traces"
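
The new `RESOURCES_DIR` and `DASHBOARD_DIR` constants follow the same override pattern as `EXO_MODELS_DIR`: use a discovered default unless an environment variable is set, in which case the value is joined onto `Path.home()`. A small sketch of that behaviour using only the standard library; `find_resources` here is a hypothetical stub with a made-up return path, and `resolve_dir` is an illustrative helper that does not exist in the diff:

```python
import os
from collections.abc import Callable
from pathlib import Path


def find_resources() -> Path:
    # Hypothetical stub; the real helper lives in exo.utils.dashboard_path.
    return Path("/opt/exo/resources")


def resolve_dir(env_name: str, fallback: Callable[[], Path]) -> Path:
    # Mirrors the expression in the diff:
    #   fallback() if env is None else Path.home() / env
    value = os.environ.get(env_name, None)
    return fallback() if value is None else Path.home() / value


# Unset: the discovered default is used.
os.environ.pop("EXO_RESOURCES_DIR", None)
default_dir = resolve_dir("EXO_RESOURCES_DIR", find_resources)

# Relative value: resolved under the home directory.
os.environ["EXO_RESOURCES_DIR"] = "exo-resources"
relative_dir = resolve_dir("EXO_RESOURCES_DIR", find_resources)

# Absolute value: pathlib discards the home prefix when the right-hand
# operand of "/" is itself absolute.
abs_value = str(Path.cwd() / "resources")
os.environ["EXO_RESOURCES_DIR"] = abs_value
absolute_dir = resolve_dir("EXO_RESOURCES_DIR", find_resources)

os.environ.pop("EXO_RESOURCES_DIR", None)
```

Because of how `pathlib` joins paths, an absolute override is used as-is, while a relative one is interpreted relative to the user's home directory rather than the current working directory.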

View File

@@ -86,28 +86,29 @@ class Election:
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
try:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election finished")
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
finally:
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election shutdown")
async def elect(self, em: ElectionMessage) -> None:
logger.debug(f"Electing: {em}")

Some files were not shown because too many files have changed in this diff.