Compare commits

..

1 Commit

Author: Evan
SHA1: e4256fa284
Message: fix InstanceViewModel.swift (wasn't caught when we merged the API changes)
Date: 2026-02-02 17:57:02 +00:00
141 changed files with 2644 additions and 8049 deletions

View File

@@ -142,6 +142,4 @@ jobs:
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
export EXO_DASHBOARD_DIR="$PWD/dashboard/"
export EXO_RESOURCES_DIR="$PWD/resources"
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

.gitignore vendored
View File

@@ -31,4 +31,3 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp

View File

@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
__all__ = ["Flux1Kontext"]

View File

@@ -1,49 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
from typing import Any
from mlx import nn
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
    CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.model.flux_vae.vae import VAE
from mflux.utils.generated_image import GeneratedImage
class Flux1Kontext(nn.Module):
    vae: VAE
    transformer: Transformer
    t5_text_encoder: T5Encoder
    clip_text_encoder: CLIPEncoder
    bits: int | None
    lora_paths: list[str] | None
    lora_scales: list[float] | None
    prompt_cache: dict[str, Any]
    tokenizers: dict[str, Any]
    def __init__(
        self,
        quantize: int | None = ...,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
        model_config: ModelConfig = ...,
    ) -> None: ...
    def generate_image(
        self,
        seed: int,
        prompt: str,
        num_inference_steps: int = ...,
        height: int = ...,
        width: int = ...,
        guidance: float = ...,
        image_path: Path | str | None = ...,
        image_strength: float | None = ...,
        scheduler: str = ...,
    ) -> GeneratedImage: ...
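
For reference, a minimal sketch of how the API declared by this deleted stub would be exercised. This is an assumption-laden illustration, not the project's code: it assumes mflux's runtime matches the stub, and the prompt, file names, and save() call are hypothetical.

# Sketch only: exercises Flux1Kontext as declared in the deleted stub above.
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext

model = Flux1Kontext(quantize=8)  # 8-bit weights; other args keep stub defaults
image = model.generate_image(
    seed=42,
    prompt="a watercolor city skyline",  # hypothetical prompt
    num_inference_steps=25,
    height=1024,
    width=1024,
    image_path="reference.png",  # Kontext conditions on an input image
)
image.save("out.png")  # GeneratedImage is assumed to expose a save() helper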

View File

@@ -1,16 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.flux.model.flux_vae.vae import VAE
class KontextUtil:
    @staticmethod
    def create_image_conditioning_latents(
        vae: VAE,
        height: int,
        width: int,
        image_path: str,
    ) -> tuple[mx.array, mx.array]: ...
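
And a similarly hedged sketch of the helper this stub declared; the bare VAE() construction is illustrative only, since the stub does not show how a VAE is actually loaded.

# Sketch only: KontextUtil as declared in the deleted stub above.
import mlx.core as mx
from mflux.models.flux.model.flux_vae.vae import VAE

vae = VAE()  # hypothetical; real use presumably requires loaded weights
latents, latent_ids = KontextUtil.create_image_conditioning_latents(
    vae=vae,
    height=1024,
    width=1024,
    image_path="reference.png",
)
assert isinstance(latents, mx.array) and isinstance(latent_ids, mx.array)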

View File

@@ -108,7 +108,6 @@ class TokenizerWrapper:
    _tokenizer: PreTrainedTokenizerFast
    eos_token_id: int | None
    eos_token: str | None
    eos_token_ids: list[int] | set[int] | None
    bos_token_id: int | None
    bos_token: str | None
    vocab_size: int
@@ -118,7 +117,7 @@ class TokenizerWrapper:
        self,
        tokenizer: Any,
        detokenizer_class: Any = ...,
        eos_token_ids: list[int] | set[int] | None = ...,
        eos_token_ids: list[int] | None = ...,
        chat_template: Any = ...,
        tool_parser: Any = ...,
        tool_call_start: str | None = ...,
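
These two hunks narrow eos_token_ids from list[int] | set[int] | None to list[int] | None (the old and new parameter lines appear back to back above, and the first hunk drops one line from the attribute block, plausibly the set-typed attribute). A hedged sketch of a call site after the change; hf_tokenizer is a hypothetical transformers tokenizer instance:

# Sketch only: after this change, eos token ids must be passed as a list.
wrapper = TokenizerWrapper(
    tokenizer=hf_tokenizer,
    eos_token_ids=[128001, 128009],  # a set like {128001, 128009} no longer type-checks
)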

View File

@@ -14,6 +14,7 @@ import SwiftUI
import UserNotifications
import os.log
@main
struct EXOApp: App {
    @StateObject private var controller: ExoProcessController
    @StateObject private var stateService: ClusterStateService

View File

@@ -288,61 +288,6 @@ enum NetworkSetupHelper {
"""
}
/// Direct install without GUI (requires root).
/// Returns true on success, false on failure.
static func installDirectly() -> Bool {
let script = makeInstallerScript()
return runShellDirectly(script)
}
/// Direct uninstall without GUI (requires root).
/// Returns true on success, false on failure.
static func uninstallDirectly() -> Bool {
let script = makeUninstallScript()
return runShellDirectly(script)
}
/// Run a shell script directly via Process (no AppleScript, requires root).
/// Returns true on success, false on failure.
private static func runShellDirectly(_ script: String) -> Bool {
let process = Process()
process.executableURL = URL(fileURLWithPath: "/bin/bash")
process.arguments = ["-c", script]
let outputPipe = Pipe()
let errorPipe = Pipe()
process.standardOutput = outputPipe
process.standardError = errorPipe
do {
try process.run()
process.waitUntilExit()
let outputData = outputPipe.fileHandleForReading.readDataToEndOfFile()
let errorData = errorPipe.fileHandleForReading.readDataToEndOfFile()
if let output = String(data: outputData, encoding: .utf8), !output.isEmpty {
print(output)
}
if let errorOutput = String(data: errorData, encoding: .utf8), !errorOutput.isEmpty {
fputs(errorOutput, stderr)
}
if process.terminationStatus == 0 {
logger.info("Shell script completed successfully")
return true
} else {
logger.error("Shell script failed with exit code \(process.terminationStatus)")
return false
}
} catch {
logger.error(
"Failed to run shell script: \(error.localizedDescription, privacy: .public)")
fputs("Error: \(error.localizedDescription)\n", stderr)
return false
}
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script

View File

@@ -1,85 +0,0 @@
//
// main.swift
// EXO
//
// Created by Jake Hillion on 2026-02-03.
//

import Foundation

/// Command line options for the EXO app
enum CLICommand {
    case install
    case uninstall
    case help
    case none
}

/// Parse command line arguments to determine the CLI command
func parseArguments() -> CLICommand {
    let args = CommandLine.arguments
    if args.contains("--help") || args.contains("-h") {
        return .help
    }
    if args.contains("--install") {
        return .install
    }
    if args.contains("--uninstall") {
        return .uninstall
    }
    return .none
}

/// Print usage information
func printUsage() {
    let programName = (CommandLine.arguments.first as NSString?)?.lastPathComponent ?? "EXO"
    print(
        """
        Usage: \(programName) [OPTIONS]

        Options:
          --install     Install EXO network configuration (requires root)
          --uninstall   Uninstall EXO network configuration (requires root)
          --help, -h    Show this help message

        When run without options, starts the normal GUI application.

        Examples:
          sudo \(programName) --install      Install network components as root
          sudo \(programName) --uninstall    Remove network components as root
        """)
}

/// Check if running as root
func isRunningAsRoot() -> Bool {
    return getuid() == 0
}

// Main entry point
let command = parseArguments()
switch command {
case .help:
    printUsage()
    exit(0)
case .install:
    if !isRunningAsRoot() {
        fputs("Error: --install requires root privileges. Run with sudo.\n", stderr)
        exit(1)
    }
    let success = NetworkSetupHelper.installDirectly()
    exit(success ? 0 : 1)
case .uninstall:
    if !isRunningAsRoot() {
        fputs("Error: --uninstall requires root privileges. Run with sudo.\n", stderr)
        exit(1)
    }
    let success = NetworkSetupHelper.uninstallDirectly()
    exit(success ? 0 : 1)
case .none:
    // Start normal GUI application
    EXOApp.main()
}

View File

@@ -6,13 +6,11 @@
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
interface Props {
class?: string;
@@ -101,23 +99,6 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty heatmap toggle
let heatmapMessageIds = $state<Set<string>>(new Set());
function toggleHeatmap(messageId: string) {
const next = new Set(heatmapMessageIds);
if (next.has(messageId)) {
next.delete(messageId);
} else {
next.add(messageId);
}
heatmapMessageIds = next;
}
function isHeatmapVisible(messageId: string): boolean {
return heatmapMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString("en-US", {
hour12: false,
@@ -567,23 +548,13 @@
>
</div>
{:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
{#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading &&
isLastAssistantMessage(message.id)}
onRegenerateFrom={(tokenIndex) =>
regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
{/if}
</div>
@@ -658,35 +629,6 @@
</button>
{/if}
<!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
{#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleHeatmap(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
message.id,
)
? 'text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
title={isHeatmapVisible(message.id)
? "Hide uncertainty heatmap"
: "Show uncertainty heatmap"}
>
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
/>
</svg>
</button>
{/if}
<!-- Regenerate button (last assistant message only) -->
{#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
<button

View File

@@ -1,73 +0,0 @@
<script lang="ts">
type FamilyLogoProps = {
family: string;
class?: string;
};
let { family, class: className = "" }: FamilyLogoProps = $props();
</script>
{#if family === "favorites"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else if family === "llama" || family === "meta"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M6.915 4.03c-1.968 0-3.683 1.28-4.871 3.113C.704 9.208 0 11.883 0 14.449c0 .706.07 1.369.21 1.973a6.624 6.624 0 0 0 .265.86 5.297 5.297 0 0 0 .371.761c.696 1.159 1.818 1.927 3.593 1.927 1.497 0 2.633-.671 3.965-2.444.76-1.012 1.144-1.626 2.663-4.32l.756-1.339.186-.325c.061.1.121.196.183.3l2.152 3.595c.724 1.21 1.665 2.556 2.47 3.314 1.046.987 1.992 1.22 3.06 1.22 1.075 0 1.876-.355 2.455-.843a3.743 3.743 0 0 0 .81-.973c.542-.939.861-2.127.861-3.745 0-2.72-.681-5.357-2.084-7.45-1.282-1.912-2.957-2.93-4.716-2.93-1.047 0-2.088.467-3.053 1.308-.652.57-1.257 1.29-1.82 2.05-.69-.875-1.335-1.547-1.958-2.056-1.182-.966-2.315-1.303-3.454-1.303zm10.16 2.053c1.147 0 2.188.758 2.992 1.999 1.132 1.748 1.647 4.195 1.647 6.4 0 1.548-.368 2.9-1.839 2.9-.58 0-1.027-.23-1.664-1.004-.496-.601-1.343-1.878-2.832-4.358l-.617-1.028a44.908 44.908 0 0 0-1.255-1.98c.07-.109.141-.224.211-.327 1.12-1.667 2.118-2.602 3.358-2.602zm-10.201.553c1.265 0 2.058.791 2.675 1.446.307.327.737.871 1.234 1.579l-1.02 1.566c-.757 1.163-1.882 3.017-2.837 4.338-1.191 1.649-1.81 1.817-2.486 1.817-.524 0-1.038-.237-1.383-.794-.263-.426-.464-1.13-.464-2.046 0-2.221.63-4.535 1.66-6.088.454-.687.964-1.226 1.533-1.533a2.264 2.264 0 0 1 1.088-.285z"
/>
</svg>
{:else if family === "qwen"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.604 1.34c.393.69.784 1.382 1.174 2.075a.18.18 0 00.157.091h5.552c.174 0 .322.11.446.327l1.454 2.57c.19.337.24.478.024.837-.26.43-.513.864-.76 1.3l-.367.658c-.106.196-.223.28-.04.512l2.652 4.637c.172.301.111.494-.043.77-.437.785-.882 1.564-1.335 2.34-.159.272-.352.375-.68.37-.777-.016-1.552-.01-2.327.016a.099.099 0 00-.081.05 575.097 575.097 0 01-2.705 4.74c-.169.293-.38.363-.725.364-.997.003-2.002.004-3.017.002a.537.537 0 01-.465-.271l-1.335-2.323a.09.09 0 00-.083-.049H4.982c-.285.03-.553-.001-.805-.092l-1.603-2.77a.543.543 0 01-.002-.54l1.207-2.12a.198.198 0 000-.197 550.951 550.951 0 01-1.875-3.272l-.79-1.395c-.16-.31-.173-.496.095-.965.465-.813.927-1.625 1.387-2.436.132-.234.304-.334.584-.335a338.3 338.3 0 012.589-.001.124.124 0 00.107-.063l2.806-4.895a.488.488 0 01.422-.246c.524-.001 1.053 0 1.583-.006L11.704 1c.341-.003.724.032.9.34zm-3.432.403a.06.06 0 00-.052.03L6.254 6.788a.157.157 0 01-.135.078H3.253c-.056 0-.07.025-.041.074l5.81 10.156c.025.042.013.062-.034.063l-2.795.015a.218.218 0 00-.2.116l-1.32 2.31c-.044.078-.021.118.068.118l5.716.008c.046 0 .08.02.104.061l1.403 2.454c.046.081.092.082.139 0l5.006-8.76.783-1.382a.055.055 0 01.096 0l1.424 2.53a.122.122 0 00.107.062l2.763-.02a.04.04 0 00.035-.02.041.041 0 000-.04l-2.9-5.086a.108.108 0 010-.113l.293-.507 1.12-1.977c.024-.041.012-.062-.035-.062H9.2c-.059 0-.073-.026-.043-.077l1.434-2.505a.107.107 0 000-.114L9.225 1.774a.06.06 0 00-.053-.031zm6.29 8.02c.046 0 .058.02.034.06l-.832 1.465-2.613 4.585a.056.056 0 01-.05.029.058.058 0 01-.05-.029L8.498 9.841c-.02-.034-.01-.052.028-.054l.216-.012 6.722-.012z"
/>
</svg>
{:else if family === "deepseek"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M23.748 4.482c-.254-.124-.364.113-.512.234-.051.039-.094.09-.137.136-.372.397-.806.657-1.373.626-.829-.046-1.537.214-2.163.848-.133-.782-.575-1.248-1.247-1.548-.352-.156-.708-.311-.955-.65-.172-.241-.219-.51-.305-.774-.055-.16-.11-.323-.293-.35-.2-.031-.278.136-.356.276-.313.572-.434 1.202-.422 1.84.027 1.436.633 2.58 1.838 3.393.137.093.172.187.129.323-.082.28-.18.552-.266.833-.055.179-.137.217-.329.14a5.526 5.526 0 01-1.736-1.18c-.857-.828-1.631-1.742-2.597-2.458a11.365 11.365 0 00-.689-.471c-.985-.957.13-1.743.388-1.836.27-.098.093-.432-.779-.428-.872.004-1.67.295-2.687.684a3.055 3.055 0 01-.465.137 9.597 9.597 0 00-2.883-.102c-1.885.21-3.39 1.102-4.497 2.623C.082 8.606-.231 10.684.152 12.85c.403 2.284 1.569 4.175 3.36 5.653 1.858 1.533 3.997 2.284 6.438 2.14 1.482-.085 3.133-.284 4.994-1.86.47.234.962.327 1.78.397.63.059 1.236-.03 1.705-.128.735-.156.684-.837.419-.961-2.155-1.004-1.682-.595-2.113-.926 1.096-1.296 2.746-2.642 3.392-7.003.05-.347.007-.565 0-.845-.004-.17.035-.237.23-.256a4.173 4.173 0 001.545-.475c1.396-.763 1.96-2.015 2.093-3.517.02-.23-.004-.467-.247-.588zM11.581 18c-2.089-1.642-3.102-2.183-3.52-2.16-.392.024-.321.471-.235.763.09.288.207.486.371.739.114.167.192.416-.113.603-.673.416-1.842-.14-1.897-.167-1.361-.802-2.5-1.86-3.301-3.307-.774-1.393-1.224-2.887-1.298-4.482-.02-.386.093-.522.477-.592a4.696 4.696 0 011.529-.039c2.132.312 3.946 1.265 5.468 2.774.868.86 1.525 1.887 2.202 2.891.72 1.066 1.494 2.082 2.48 2.914.348.292.625.514.891.677-.802.09-2.14.11-3.054-.614zm1-6.44a.306.306 0 01.415-.287.302.302 0 01.2.288.306.306 0 01-.31.307.303.303 0 01-.304-.308zm3.11 1.596c-.2.081-.399.151-.59.16a1.245 1.245 0 01-.798-.254c-.274-.23-.47-.358-.552-.758a1.73 1.73 0 01.016-.588c.07-.327-.008-.537-.239-.727-.187-.156-.426-.199-.688-.199a.559.559 0 01-.254-.078c-.11-.054-.2-.19-.114-.358.028-.054.16-.186.192-.21.356-.202.767-.136 1.146.016.352.144.618.408 1.001.782.391.451.462.576.685.914.176.265.336.537.445.848.067.195-.019.354-.25.452z"
/>
</svg>
{:else if family === "openai" || family === "gpt-oss"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M22.2819 9.8211a5.9847 5.9847 0 0 0-.5157-4.9108 6.0462 6.0462 0 0 0-6.5098-2.9A6.0651 6.0651 0 0 0 4.9807 4.1818a5.9847 5.9847 0 0 0-3.9977 2.9 6.0462 6.0462 0 0 0 .7427 7.0966 5.98 5.98 0 0 0 .511 4.9107 6.051 6.051 0 0 0 6.5146 2.9001A5.9847 5.9847 0 0 0 13.2599 24a6.0557 6.0557 0 0 0 5.7718-4.2058 5.9894 5.9894 0 0 0 3.9977-2.9001 6.0557 6.0557 0 0 0-.7475-7.0729zm-9.022 12.6081a4.4755 4.4755 0 0 1-2.8764-1.0408l.1419-.0804 4.7783-2.7582a.7948.7948 0 0 0 .3927-.6813v-6.7369l2.02 1.1686a.071.071 0 0 1 .038.052v5.5826a4.504 4.504 0 0 1-4.4945 4.4944zm-9.6607-4.1254a4.4708 4.4708 0 0 1-.5346-3.0137l.142.0852 4.783 2.7582a.7712.7712 0 0 0 .7806 0l5.8428-3.3685v2.3324a.0804.0804 0 0 1-.0332.0615L9.74 19.9502a4.4992 4.4992 0 0 1-6.1408-1.6464zM2.3408 7.8956a4.485 4.485 0 0 1 2.3655-1.9728V11.6a.7664.7664 0 0 0 .3879.6765l5.8144 3.3543-2.0201 1.1685a.0757.0757 0 0 1-.071 0l-4.8303-2.7865A4.504 4.504 0 0 1 2.3408 7.872zm16.5963 3.8558L13.1038 8.364 15.1192 7.2a.0757.0757 0 0 1 .071 0l4.8303 2.7913a4.4944 4.4944 0 0 1-.6765 8.1042v-5.6772a.79.79 0 0 0-.407-.667zm2.0107-3.0231l-.142-.0852-4.7735-2.7818a.7759.7759 0 0 0-.7854 0L9.409 9.2297V6.8974a.0662.0662 0 0 1 .0284-.0615l4.8303-2.7866a4.4992 4.4992 0 0 1 6.6802 4.66zM8.3065 12.863l-2.02-1.1638a.0804.0804 0 0 1-.038-.0567V6.0742a4.4992 4.4992 0 0 1 7.3757-3.4537l-.142.0805L8.704 5.459a.7948.7948 0 0 0-.3927.6813zm1.0976-2.3654l2.602-1.4998 2.6069 1.4998v2.9994l-2.5974 1.4997-2.6067-1.4997Z"
/>
</svg>
{:else if family === "glm"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M11.991 23.503a.24.24 0 00-.244.248.24.24 0 00.244.249.24.24 0 00.245-.249.24.24 0 00-.22-.247l-.025-.001zM9.671 5.365a1.697 1.697 0 011.099 2.132l-.071.172-.016.04-.018.054c-.07.16-.104.32-.104.498-.035.71.47 1.279 1.186 1.314h.366c1.309.053 2.338 1.173 2.286 2.523-.052 1.332-1.152 2.38-2.478 2.327h-.174c-.715.018-1.274.64-1.239 1.368 0 .124.018.23.053.337.209.373.54.658.96.8.75.23 1.517-.125 1.9-.782l.018-.035c.402-.64 1.17-.96 1.92-.711.854.284 1.378 1.226 1.099 2.167a1.661 1.661 0 01-2.077 1.102 1.711 1.711 0 01-.907-.711l-.017-.035c-.2-.323-.463-.58-.851-.711l-.056-.018a1.646 1.646 0 00-1.954.746 1.66 1.66 0 01-1.065.764 1.677 1.677 0 01-1.989-1.279c-.209-.906.332-1.83 1.257-2.043a1.51 1.51 0 01.296-.035h.018c.68-.071 1.151-.622 1.116-1.333a1.307 1.307 0 00-.227-.693 2.515 2.515 0 01-.366-1.403 2.39 2.39 0 01.366-1.208c.14-.195.21-.444.227-.693.018-.71-.506-1.261-1.186-1.332l-.07-.018a1.43 1.43 0 01-.299-.07l-.05-.019a1.7 1.7 0 01-1.047-2.114 1.68 1.68 0 012.094-1.101zm-5.575 10.11c.26-.264.639-.367.994-.27.355.096.633.379.728.74.095.362-.007.748-.267 1.013-.402.41-1.053.41-1.455 0a1.062 1.062 0 010-1.482zm14.845-.294c.359-.09.738.024.992.297.254.274.344.665.237 1.025-.107.36-.396.634-.756.718-.551.128-1.1-.22-1.23-.781a1.05 1.05 0 01.757-1.26zm-.064-4.39c.314.32.49.753.49 1.206 0 .452-.176.886-.49 1.206-.315.32-.74.5-1.185.5-.444 0-.87-.18-1.184-.5a1.727 1.727 0 010-2.412 1.654 1.654 0 012.369 0zm-11.243.163c.364.484.447 1.128.218 1.691a1.665 1.665 0 01-2.188.923c-.855-.36-1.26-1.358-.907-2.228a1.68 1.68 0 011.33-1.038c.593-.08 1.183.169 1.547.652zm11.545-4.221c.368 0 .708.2.892.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.892.524c-.568 0-1.03-.47-1.03-1.048 0-.579.462-1.048 1.03-1.048zm-14.358 0c.368 0 .707.2.891.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.891.524c-.569 0-1.03-.47-1.03-1.048 0-.579.461-1.048 1.03-1.048zm10.031-1.475c.925 0 1.675.764 1.675 1.706s-.75 1.705-1.675 1.705-1.674-.763-1.674-1.705c0-.942.75-1.706 1.674-1.706zm-2.626-.684c.362-.082.653-.356.761-.718a1.062 1.062 0 00-.238-1.028 1.017 1.017 0 00-.996-.294c-.547.14-.881.7-.752 1.257.13.558.675.907 1.225.783zm0 16.876c.359-.087.644-.36.75-.72a1.062 1.062 0 00-.237-1.019 1.018 1.018 0 00-.985-.301 1.037 1.037 0 00-.762.717c-.108.361-.017.754.239 1.028.245.263.606.377.953.305l.043-.01zM17.19 3.5a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64a.631.631 0 00-.628.64c0 .355.28.64.628.64zm-10.38 0a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64a.631.631 0 00-.628.64c0 .355.279.64.628.64zm-5.182 7.852a.631.631 0 00-.628.64c0 .354.28.639.628.639a.63.63 0 00.627-.606l.001-.034a.62.62 0 00-.628-.64zm5.182 9.13a.631.631 0 00-.628.64c0 .355.279.64.628.64a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm10.38.018a.631.631 0 00-.628.64c0 .355.28.64.628.64a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64zm5.182-9.148a.631.631 0 00-.628.64c0 .354.279.639.628.639a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm-.384-4.992a.24.24 0 00.244-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249c0 .142.122.249.244.249zM11.991.497a.24.24 0 00.245-.248A.24.24 0 0011.99 0a.24.24 0 00-.244.249c0 .133.108.236.223.247l.021.001zM2.011 6.36a.24.24 0 00.245-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249.24.24 0 00.244.249zm0 11.263a.24.24 0 00-.243.248.24.24 0 00.244.249.24.24 0 00.244-.249.252.252 0 00-.244-.248zm19.995-.018a.24.24 0 00-.245.248.24.24 0 00.245.25.24.24 0 00.244-.25.252.252 0 00-.244-.248z"
/>
</svg>
{:else if family === "minimax"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M16.278 2c1.156 0 2.093.927 2.093 2.07v12.501a.74.74 0 00.744.709.74.74 0 00.743-.709V9.099a2.06 2.06 0 012.071-2.049A2.06 2.06 0 0124 9.1v6.561a.649.649 0 01-.652.645.649.649 0 01-.653-.645V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v7.472a2.037 2.037 0 01-2.048 2.026 2.037 2.037 0 01-2.048-2.026v-12.5a.785.785 0 00-.788-.753.785.785 0 00-.789.752l-.001 15.904A2.037 2.037 0 0113.441 22a2.037 2.037 0 01-2.048-2.026V18.04c0-.356.292-.645.652-.645.36 0 .652.289.652.645v1.934c0 .263.142.506.372.638.23.131.514.131.744 0a.734.734 0 00.372-.638V4.07c0-1.143.937-2.07 2.093-2.07zm-5.674 0c1.156 0 2.093.927 2.093 2.07v11.523a.648.648 0 01-.652.645.648.648 0 01-.652-.645V4.07a.785.785 0 00-.789-.78.785.785 0 00-.789.78v14.013a2.06 2.06 0 01-2.07 2.048 2.06 2.06 0 01-2.071-2.048V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v3.8a2.06 2.06 0 01-2.071 2.049A2.06 2.06 0 010 12.9v-1.378c0-.357.292-.646.652-.646.36 0 .653.29.653.646V12.9c0 .418.343.757.766.757s.766-.339.766-.757V9.099a2.06 2.06 0 012.07-2.048 2.06 2.06 0 012.071 2.048v8.984c0 .419.343.758.767.758.423 0 .766-.339.766-.758V4.07c0-1.143.937-2.07 2.093-2.07z"
/>
</svg>
{:else if family === "kimi"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19.738 5.776c.163-.209.306-.4.457-.585.07-.087.064-.153-.004-.244-.655-.861-.717-1.817-.34-2.787.283-.73.909-1.072 1.674-1.145.477-.045.945.004 1.379.236.57.305.902.77 1.01 1.412.086.512.07 1.012-.075 1.508-.257.878-.888 1.333-1.753 1.448-.718.096-1.446.108-2.17.157-.056.004-.113 0-.178 0z"
/>
<path
d="M17.962 1.844h-4.326l-3.425 7.81H5.369V1.878H1.5V22h3.87v-8.477h6.824a3.025 3.025 0 002.743-1.75V22h3.87v-8.477a3.87 3.87 0 00-3.588-3.86v-.01h-2.125a3.94 3.94 0 002.323-2.12l2.545-5.689z"
/>
</svg>
{:else if family === "huggingface"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.025 1.13c-5.77 0-10.449 4.647-10.449 10.378 0 1.112.178 2.181.503 3.185.064-.222.203-.444.416-.577a.96.96 0 0 1 .524-.15c.293 0 .584.124.84.284.278.173.48.408.71.694.226.282.458.611.684.951v-.014c.017-.324.106-.622.264-.874s.403-.487.762-.543c.3-.047.596.06.787.203s.31.313.4.467c.15.257.212.468.233.542.01.026.653 1.552 1.657 2.54.616.605 1.01 1.223 1.082 1.912.055.537-.096 1.059-.38 1.572.637.121 1.294.187 1.967.187.657 0 1.298-.063 1.921-.178-.287-.517-.44-1.041-.384-1.581.07-.69.465-1.307 1.081-1.913 1.004-.987 1.647-2.513 1.657-2.539.021-.074.083-.285.233-.542.09-.154.208-.323.4-.467a1.08 1.08 0 0 1 .787-.203c.359.056.604.29.762.543s.247.55.265.874v.015c.225-.34.457-.67.683-.952.23-.286.432-.52.71-.694.257-.16.547-.284.84-.285a.97.97 0 0 1 .524.151c.228.143.373.388.43.625l.006.04a10.3 10.3 0 0 0 .534-3.273c0-5.731-4.678-10.378-10.449-10.378M8.327 6.583a1.5 1.5 0 0 1 .713.174 1.487 1.487 0 0 1 .617 2.013c-.183.343-.762-.214-1.102-.094-.38.134-.532.914-.917.71a1.487 1.487 0 0 1 .69-2.803m7.486 0a1.487 1.487 0 0 1 .689 2.803c-.385.204-.536-.576-.916-.71-.34-.12-.92.437-1.103.094a1.487 1.487 0 0 1 .617-2.013 1.5 1.5 0 0 1 .713-.174m-10.68 1.55a.96.96 0 1 1 0 1.921.96.96 0 0 1 0-1.92m13.838 0a.96.96 0 1 1 0 1.92.96.96 0 0 1 0-1.92M8.489 11.458c.588.01 1.965 1.157 3.572 1.164 1.607-.007 2.984-1.155 3.572-1.164.196-.003.305.12.305.454 0 .886-.424 2.328-1.563 3.202-.22-.756-1.396-1.366-1.63-1.32q-.011.001-.02.006l-.044.026-.01.008-.03.024q-.018.017-.035.036l-.032.04a1 1 0 0 0-.058.09l-.014.025q-.049.088-.11.19a1 1 0 0 1-.083.116 1.2 1.2 0 0 1-.173.18q-.035.029-.075.058a1.3 1.3 0 0 1-.251-.243 1 1 0 0 1-.076-.107c-.124-.193-.177-.363-.337-.444-.034-.016-.104-.008-.2.022q-.094.03-.216.087-.06.028-.125.063l-.13.074q-.067.04-.136.086a3 3 0 0 0-.135.096 3 3 0 0 0-.26.219 2 2 0 0 0-.12.121 2 2 0 0 0-.106.128l-.002.002a2 2 0 0 0-.09.132l-.001.001a1.2 1.2 0 0 0-.105.212q-.013.036-.024.073c-1.139-.875-1.563-2.317-1.563-3.203 0-.334.109-.457.305-.454m.836 10.354c.824-1.19.766-2.082-.365-3.194-1.13-1.112-1.789-2.738-1.789-2.738s-.246-.945-.806-.858-.97 1.499.202 2.362c1.173.864-.233 1.45-.685.64-.45-.812-1.683-2.896-2.322-3.295s-1.089-.175-.938.647 2.822 2.813 2.562 3.244-1.176-.506-1.176-.506-2.866-2.567-3.49-1.898.473 1.23 2.037 2.16c1.564.932 1.686 1.178 1.464 1.53s-3.675-2.511-4-1.297c-.323 1.214 3.524 1.567 3.287 2.405-.238.839-2.71-1.587-3.216-.642-.506.946 3.49 2.056 3.522 2.064 1.29.33 4.568 1.028 5.713-.624m5.349 0c-.824-1.19-.766-2.082.365-3.194 1.13-1.112 1.789-2.738 1.789-2.738s.246-.945.806-.858.97 1.499-.202 2.362c-1.173.864.233 1.45.685.64.451-.812 1.683-2.896 2.322-3.295s1.089-.175.938.647-2.822 2.813-2.562 3.244 1.176-.506 1.176-.506 2.866-2.567 3.49-1.898-.473 1.23-2.037 2.16c-1.564.932-1.686 1.178-1.464 1.53s3.675-2.511 4-1.297c.323 1.214-3.524 1.567-3.287 2.405.238.839 2.71-1.587 3.216-.642.506.946-3.49 2.056-3.522 2.064-1.29.33-4.568 1.028-5.713-.624"
/>
</svg>
{:else}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
{/if}

View File

@@ -1,142 +0,0 @@
<script lang="ts">
import FamilyLogos from "./FamilyLogos.svelte";
type FamilySidebarProps = {
families: string[];
selectedFamily: string | null;
hasFavorites: boolean;
onSelect: (family: string | null) => void;
};
let { families, selectedFamily, hasFavorites, onSelect }: FamilySidebarProps =
$props();
// Family display names
const familyNames: Record<string, string> = {
favorites: "Favorites",
huggingface: "Hub",
llama: "Meta",
qwen: "Qwen",
deepseek: "DeepSeek",
"gpt-oss": "OpenAI",
glm: "GLM",
minimax: "MiniMax",
kimi: "Kimi",
};
function getFamilyName(family: string): string {
return (
familyNames[family] || family.charAt(0).toUpperCase() + family.slice(1)
);
}
</script>
<div
class="flex flex-col gap-1 py-2 px-1 border-r border-exo-yellow/10 bg-exo-medium-gray/30 min-w-[64px]"
>
<!-- All models (no filter) -->
<button
type="button"
onclick={() => onSelect(null)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
null
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="All models"
>
<svg
class="w-5 h-5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/50 group-hover:text-white/70'}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M4 8h4V4H4v4zm6 12h4v-4h-4v4zm-6 0h4v-4H4v4zm0-6h4v-4H4v4zm6 0h4v-4h-4v4zm6-10v4h4V4h-4zm-6 4h4V4h-4v4zm6 6h4v-4h-4v4zm0 6h4v-4h-4v4z"
/>
</svg>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}">All</span
>
</button>
<!-- Favorites (only show if has favorites) -->
{#if hasFavorites}
<button
type="button"
onclick={() => onSelect("favorites")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'favorites'
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Show favorited models"
>
<FamilyLogos
family="favorites"
class={selectedFamily === "favorites"
? "text-amber-400"
: "text-white/50 group-hover:text-amber-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'favorites'
? 'text-amber-400'
: 'text-white/40 group-hover:text-white/60'}">Faves</span
>
</button>
{/if}
<!-- HuggingFace Hub -->
<button
type="button"
onclick={() => onSelect("huggingface")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'huggingface'
? 'bg-orange-500/20 border-l-2 border-orange-400'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Browse and add models from Hugging Face"
>
<FamilyLogos
family="huggingface"
class={selectedFamily === "huggingface"
? "text-orange-400"
: "text-white/50 group-hover:text-orange-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'huggingface'
? 'text-orange-400'
: 'text-white/40 group-hover:text-white/60'}">Hub</span
>
</button>
<div class="h-px bg-exo-yellow/10 my-1"></div>
<!-- Model families -->
{#each families as family}
<button
type="button"
onclick={() => onSelect(family)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
family
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title={getFamilyName(family)}
>
<FamilyLogos
{family}
class={selectedFamily === family
? "text-exo-yellow"
: "text-white/50 group-hover:text-white/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 truncate max-w-full {selectedFamily ===
family
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}"
>
{getFamilyName(family)}
</span>
</button>
{/each}
</div>

View File

@@ -1,151 +0,0 @@
<script lang="ts">
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type HuggingFaceResultItemProps = {
model: HuggingFaceModel;
isAdded: boolean;
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
downloadedOnNodes?: string[];
};
let {
model,
isAdded,
isAdding,
onAdd,
onSelect,
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {
return `${(num / 1000).toFixed(1)}k`;
}
return num.toString();
}
// Extract model name from full ID (e.g., "mlx-community/Llama-3.2-1B" -> "Llama-3.2-1B")
const modelName = $derived(model.id.split("/").pop() || model.id);
</script>
<div
class="flex items-center justify-between gap-3 px-3 py-2.5 hover:bg-white/5 transition-colors border-b border-white/5 last:border-b-0"
>
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if downloadedOnNodes.length > 0}
<span
class="flex-shrink-0"
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
>Added</span
>
{/if}
</div>
<div class="flex items-center gap-3 mt-0.5 text-xs text-white/40">
<span class="truncate">{model.author}</span>
<span
class="flex items-center gap-1 shrink-0"
title="Downloads in the last 30 days"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"
/>
</svg>
{formatNumber(model.downloads)}
</span>
<span
class="flex items-center gap-1 shrink-0"
title="Community likes on Hugging Face"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4.318 6.318a4.5 4.5 0 000 6.364L12 20.364l7.682-7.682a4.5 4.5 0 00-6.364-6.364L12 7.636l-1.318-1.318a4.5 4.5 0 00-6.364 0z"
/>
</svg>
{formatNumber(model.likes)}
</span>
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
{#if isAdded}
<button
type="button"
onclick={onSelect}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-yellow/10 text-exo-yellow border border-exo-yellow/30 hover:bg-exo-yellow/20 transition-colors rounded cursor-pointer"
>
Select
</button>
{:else}
<button
type="button"
onclick={onAdd}
disabled={isAdding}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed"
>
{#if isAdding}
<span class="flex items-center gap-1.5">
<span
class="w-3 h-3 border-2 border-orange-400 border-t-transparent rounded-full animate-spin"
></span>
Adding...
</span>
{:else}
+ Add
{/if}
</button>
{/if}
</div>
</div>

View File

@@ -1,213 +0,0 @@
<script lang="ts">
import { fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
type ModelFilterPopoverProps = {
filters: FilterState;
onChange: (filters: FilterState) => void;
onClear: () => void;
onClose: () => void;
};
let { filters, onChange, onClear, onClose }: ModelFilterPopoverProps =
$props();
// Available capabilities
const availableCapabilities = [
{ id: "text", label: "Text" },
{ id: "thinking", label: "Thinking" },
{ id: "code", label: "Code" },
{ id: "vision", label: "Vision" },
];
// Size ranges
const sizeRanges = [
{ label: "< 10GB", min: 0, max: 10 },
{ label: "10-50GB", min: 10, max: 50 },
{ label: "50-200GB", min: 50, max: 200 },
{ label: "> 200GB", min: 200, max: 10000 },
];
function toggleCapability(cap: string) {
const next = filters.capabilities.includes(cap)
? filters.capabilities.filter((c) => c !== cap)
: [...filters.capabilities, cap];
onChange({ ...filters, capabilities: next });
}
function selectSizeRange(range: { min: number; max: number } | null) {
// Toggle off if same range is clicked
if (
filters.sizeRange &&
range &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max
) {
onChange({ ...filters, sizeRange: null });
} else {
onChange({ ...filters, sizeRange: range });
}
}
function handleClickOutside(e: MouseEvent) {
const target = e.target as HTMLElement;
if (
!target.closest(".filter-popover") &&
!target.closest(".filter-toggle")
) {
onClose();
}
}
</script>
<svelte:window onclick={handleClickOutside} />
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="filter-popover absolute right-0 top-full mt-2 w-64 bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-xl z-10"
transition:fly={{ y: -10, duration: 200, easing: cubicOut }}
onclick={(e) => e.stopPropagation()}
role="dialog"
aria-label="Filter options"
>
<div class="p-3 space-y-4">
<!-- Capabilities -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Capabilities</h4>
<div class="flex flex-wrap gap-1.5">
{#each availableCapabilities as cap}
{@const isSelected = filters.capabilities.includes(cap.id)}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => toggleCapability(cap.id)}
>
{#if cap.id === "text"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "thinking"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "code"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "vision"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/><circle cx="12" cy="12" r="3" /></svg
>
{/if}
<span class="ml-1">{cap.label}</span>
</button>
{/each}
</div>
</div>
<!-- Downloaded only -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
? 'bg-green-500/20 text-green-400 border border-green-500/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() =>
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
>
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="ml-1">Downloaded</span>
</button>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
<div class="flex flex-wrap gap-1.5">
{#each sizeRanges as range}
{@const isSelected =
filters.sizeRange &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => selectSizeRange(range)}
>
{range.label}
</button>
{/each}
</div>
</div>
<!-- Clear button -->
<button
type="button"
class="w-full py-1.5 text-xs font-mono text-white/50 hover:text-white/70 hover:bg-white/5 rounded transition-colors"
onclick={onClear}
>
Clear all filters
</button>
</div>
</div>

View File

@@ -1,401 +0,0 @@
<script lang="ts">
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
isFavorite: boolean;
selectedModelId: string | null;
canModelFit: (id: string) => boolean;
onToggleExpand: () => void;
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatusMap?: Map<string, DownloadAvailability>;
};
let {
group,
isExpanded,
isFavorite,
selectedModelId,
canModelFit,
onToggleExpand,
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatusMap,
}: ModelPickerGroupProps = $props();
// Group-level download status: show if any variant is downloaded
const groupDownloadStatus = $derived.by(() => {
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
// Return the first available entry (prefer "available" ones)
for (const avail of downloadStatusMap.values()) {
if (avail.available) return avail;
}
return downloadStatusMap.values().next().value;
});
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
if (mb >= 1024) {
return `${(mb / 1024).toFixed(0)}GB`;
}
return `${mb}MB`;
}
// Check if any variant can fit
const anyVariantFits = $derived(
group.variants.some((v) => canModelFit(v.id)),
);
// Check if this group's model is currently selected (for single-variant groups)
const isMainSelected = $derived(
!group.hasMultipleVariants &&
group.variants.some((v) => v.id === selectedModelId),
);
</script>
<div
class="border-b border-white/5 last:border-b-0 {!anyVariantFits
? 'opacity-50'
: ''}"
>
<!-- Main row -->
<div
class="flex items-center gap-2 px-3 py-2.5 transition-colors {anyVariantFits
? 'hover:bg-white/5 cursor-pointer'
: 'cursor-not-allowed'} {isMainSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
onclick={() => {
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}}
role="button"
tabindex="0"
onkeydown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}
}}
>
<!-- Expand/collapse chevron (for groups with variants) -->
{#if group.hasMultipleVariants}
<svg
class="w-4 h-4 text-white/40 transition-transform duration-200 flex-shrink-0 {isExpanded
? 'rotate-90'
: ''}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M8.59 16.59L13.17 12 8.59 7.41 10 6l6 6-6 6-1.41-1.41z" />
</svg>
{:else}
<div class="w-4 flex-shrink-0"></div>
{/if}
<!-- Model name -->
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="font-mono text-sm text-white truncate">
{group.name}
</span>
<!-- Capability icons -->
{#each group.capabilities.filter((c) => c !== "text") as cap}
{#if cap === "thinking"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports Thinking"
>
<path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "code"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports code generation"
>
<path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "vision"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports image input"
>
<path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/>
<circle cx="12" cy="12" r="3" />
</svg>
{:else if cap === "image_gen"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports image generation"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<path d="M21 15l-5-5L5 21" />
</svg>
{/if}
{/each}
</div>
</div>
<!-- Size indicator (smallest variant) -->
{#if !group.hasMultipleVariants && group.smallestVariant?.storage_size_megabytes}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{formatSize(group.smallestVariant.storage_size_megabytes)}
</span>
{/if}
<!-- Variant count with size range -->
{#if group.hasMultipleVariants}
{@const sizes = group.variants
.map((v) => v.storage_size_megabytes || 0)
.filter((s) => s > 0)
.sort((a, b) => a - b)}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
sizes[0],
)}-{formatSize(sizes[sizes.length - 1])}){/if}
</span>
{/if}
<!-- Download availability indicator -->
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={groupDownloadStatus.available
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
<!-- Check mark if selected (single-variant) -->
{#if isMainSelected}
<svg
class="w-4 h-4 text-exo-yellow flex-shrink-0"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z" />
</svg>
{/if}
<!-- Favorite star -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onToggleFavorite(group.id);
}}
title={isFavorite ? "Remove from favorites" : "Add to favorites"}
>
{#if isFavorite}
<svg
class="w-4 h-4 text-amber-400"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else}
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{/if}
</button>
<!-- Info button -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onShowInfo(group);
}}
title="View model details"
>
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z"
/>
</svg>
</button>
</div>
<!-- Expanded variants -->
{#if isExpanded && group.hasMultipleVariants}
<div class="bg-black/20 border-t border-white/5">
{#each group.variants as variant}
{@const modelCanFit = canModelFit(variant.id)}
{@const isSelected = selectedModelId === variant.id}
<button
type="button"
class="w-full flex items-center gap-3 px-3 py-2 pl-10 hover:bg-white/5 transition-colors text-left {!modelCanFit
? 'opacity-50 cursor-not-allowed'
: 'cursor-pointer'} {isSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
disabled={!modelCanFit}
onclick={() => {
if (modelCanFit) {
onSelectModel(variant.id);
}
}}
>
<!-- Quantization badge -->
<span
class="text-xs font-mono px-1.5 py-0.5 rounded bg-white/10 text-white/70 flex-shrink-0"
>
{variant.quantization || "default"}
</span>
<!-- Size -->
<span class="text-xs font-mono text-white/40 flex-1">
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Download indicator for this variant -->
{#if downloadStatusMap?.get(variant.id)}
{@const variantDl = downloadStatusMap.get(variant.id)}
{#if variantDl}
<span
class="flex-shrink-0"
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{/if}
<!-- Check mark if selected -->
{#if isSelected}
<svg
class="w-4 h-4 text-exo-yellow"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z"
/>
</svg>
{/if}
</button>
{/each}
</div>
{/if}
</div>

View File

@@ -1,882 +0,0 @@
<script lang="ts">
import { fade, fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
import FamilySidebar from "./FamilySidebar.svelte";
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
tasks?: string[];
hugging_face_id?: string;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type ModelPickerModalProps = {
isOpen: boolean;
models: ModelInfo[];
selectedModelId: string | null;
favorites: Set<string>;
existingModelIds: Set<string>;
canModelFit: (modelId: string) => boolean;
onSelect: (modelId: string) => void;
onClose: () => void;
onToggleFavorite: (baseModelId: string) => void;
onAddModel: (modelId: string) => Promise<void>;
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
isOpen,
models,
selectedModelId,
favorites,
existingModelIds,
canModelFit,
onSelect,
onClose,
onToggleFavorite,
onAddModel,
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
let searchQuery = $state("");
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({
capabilities: [],
sizeRange: null,
downloadedOnly: false,
});
let infoGroup = $state<ModelGroup | null>(null);
// Download availability per model group
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Aggregate download availability per group (available if ANY variant is available)
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// Get per-variant download map for a group
function getVariantDownloadMap(
group: ModelGroup,
): Map<string, DownloadAvailability> {
const map = new Map<string, DownloadAvailability>();
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
}
return map;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
let hfTrendingModels = $state<HuggingFaceModel[]>([]);
let hfIsSearching = $state(false);
let hfIsLoadingTrending = $state(false);
let addingModelId = $state<string | null>(null);
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset transient state when modal opens, but preserve tab selection
$effect(() => {
if (isOpen) {
searchQuery = "";
expandedGroups = new Set();
showFilters = false;
manualModelId = "";
addModelError = null;
}
});
// Fetch trending models when HuggingFace is selected
$effect(() => {
if (
selectedFamily === "huggingface" &&
hfTrendingModels.length === 0 &&
!hfIsLoadingTrending
) {
fetchTrendingModels();
}
});
async function fetchTrendingModels() {
hfIsLoadingTrending = true;
try {
const response = await fetch("/models/search?query=&limit=20");
if (response.ok) {
hfTrendingModels = await response.json();
}
} catch (error) {
console.error("Failed to fetch trending models:", error);
} finally {
hfIsLoadingTrending = false;
}
}
async function searchHuggingFace(query: string) {
if (query.length < 2) {
hfSearchResults = [];
return;
}
hfIsSearching = true;
try {
const response = await fetch(
`/models/search?query=${encodeURIComponent(query)}&limit=20`,
);
if (response.ok) {
hfSearchResults = await response.json();
} else {
hfSearchResults = [];
}
} catch (error) {
console.error("Failed to search models:", error);
hfSearchResults = [];
} finally {
hfIsSearching = false;
}
}
function handleHfSearchInput(query: string) {
hfSearchQuery = query;
addModelError = null;
if (hfSearchDebounceTimer) {
clearTimeout(hfSearchDebounceTimer);
}
if (query.length >= 2) {
hfSearchDebounceTimer = setTimeout(() => {
searchHuggingFace(query);
}, 300);
} else {
hfSearchResults = [];
}
}
async function handleAddModel(modelId: string) {
addingModelId = modelId;
addModelError = null;
try {
await onAddModel(modelId);
} catch (error) {
addModelError =
error instanceof Error ? error.message : "Failed to add model";
} finally {
addingModelId = null;
}
}
async function handleAddManualModel() {
if (!manualModelId.trim()) return;
await handleAddModel(manualModelId.trim());
if (!addModelError) {
manualModelId = "";
}
}
function handleSelectHfModel(modelId: string) {
onSelect(modelId);
onClose();
}
// Models to display in HuggingFace view
const hfDisplayModels = $derived.by((): HuggingFaceModel[] => {
if (hfSearchQuery.length >= 2) {
return hfSearchResults;
}
return hfTrendingModels;
});
// Group models by base_model
const groupedModels = $derived.by((): ModelGroup[] => {
const groups = new Map<string, ModelGroup>();
for (const model of models) {
const groupId = model.base_model || model.id;
const groupName = model.base_model || model.name || model.id;
if (!groups.has(groupId)) {
groups.set(groupId, {
id: groupId,
name: groupName,
capabilities: model.capabilities || ["text"],
family: model.family || "",
variants: [],
smallestVariant: model,
hasMultipleVariants: false,
});
}
const group = groups.get(groupId)!;
group.variants.push(model);
// Track smallest variant
if (
(model.storage_size_megabytes || 0) <
(group.smallestVariant.storage_size_megabytes || Infinity)
) {
group.smallestVariant = model;
}
// Update capabilities if not set
if (
group.capabilities.length <= 1 &&
model.capabilities &&
model.capabilities.length > 1
) {
group.capabilities = model.capabilities;
}
if (!group.family && model.family) {
group.family = model.family;
}
}
// Sort variants within each group by size
for (const group of groups.values()) {
group.variants.sort(
(a, b) =>
(a.storage_size_megabytes || 0) - (b.storage_size_megabytes || 0),
);
group.hasMultipleVariants = group.variants.length > 1;
}
// Convert to array and sort by smallest variant size (biggest first)
return Array.from(groups.values()).sort((a, b) => {
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
});
// Get unique families
const uniqueFamilies = $derived.by((): string[] => {
const families = new Set<string>();
for (const group of groupedModels) {
if (group.family) {
families.add(group.family);
}
}
const familyOrder = [
"kimi",
"qwen",
"glm",
"minimax",
"deepseek",
"gpt-oss",
"llama",
];
return Array.from(families).sort((a, b) => {
const aIdx = familyOrder.indexOf(a);
const bIdx = familyOrder.indexOf(b);
if (aIdx === -1 && bIdx === -1) return a.localeCompare(b);
if (aIdx === -1) return 1;
if (bIdx === -1) return -1;
return aIdx - bIdx;
});
});
// Filter models based on search, family, and filters
const filteredGroups = $derived.by((): ModelGroup[] => {
let result: ModelGroup[] = [...groupedModels];
// Filter by family
if (selectedFamily === "favorites") {
result = result.filter((g) => favorites.has(g.id));
} else if (selectedFamily && selectedFamily !== "huggingface") {
result = result.filter((g) => g.family === selectedFamily);
}
// Filter by search query
if (searchQuery.trim()) {
const query = searchQuery.toLowerCase().trim();
result = result.filter(
(g) =>
g.name.toLowerCase().includes(query) ||
g.variants.some(
(v) =>
v.id.toLowerCase().includes(query) ||
(v.name || "").toLowerCase().includes(query),
),
);
}
// Filter by capabilities
if (filters.capabilities.length > 0) {
result = result.filter((g) =>
filters.capabilities.every((cap) => g.capabilities.includes(cap)),
);
}
// Filter by size range
if (filters.sizeRange) {
const { min, max } = filters.sizeRange;
result = result.filter((g) => {
const sizeGB = (g.smallestVariant.storage_size_megabytes || 0) / 1024;
return sizeGB >= min && sizeGB <= max;
});
}
// Filter to downloaded models only
if (filters.downloadedOnly) {
result = result.filter((g) =>
g.variants.some((v) => {
const avail = modelDownloadAvailability.get(v.id);
return avail && avail.nodeIds.length > 0;
}),
);
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
const bFits = b.variants.some((v) => canModelFit(v.id));
if (aFits && !bFits) return -1;
if (!aFits && bFits) return 1;
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
return result;
});
// Check if any favorites exist
const hasFavorites = $derived(favorites.size > 0);
function toggleGroupExpanded(groupId: string) {
const next = new Set(expandedGroups);
if (next.has(groupId)) {
next.delete(groupId);
} else {
next.add(groupId);
}
expandedGroups = next;
}
function handleSelect(modelId: string) {
onSelect(modelId);
onClose();
}
function handleKeydown(e: KeyboardEvent) {
if (e.key === "Escape") {
onClose();
}
}
function handleFiltersChange(newFilters: FilterState) {
filters = newFilters;
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 ||
filters.sizeRange !== null ||
filters.downloadedOnly,
);
</script>
<svelte:window onkeydown={handleKeydown} />
{#if isOpen}
<!-- Backdrop -->
<div
class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
transition:fade={{ duration: 200 }}
onclick={onClose}
role="presentation"
></div>
<!-- Modal -->
<div
class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,600px)] h-[min(80vh,700px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
role="dialog"
aria-modal="true"
aria-label="Select a model"
>
<!-- Header with search -->
<div
class="flex items-center gap-2 p-3 border-b border-exo-yellow/10 bg-exo-medium-gray/30"
>
{#if selectedFamily === "huggingface"}
<!-- HuggingFace search -->
<svg
class="w-5 h-5 text-orange-400/60 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search mlx-community models..."
value={hfSearchQuery}
oninput={(e) => handleHfSearchInput(e.currentTarget.value)}
/>
{#if hfIsSearching}
<div class="flex-shrink-0">
<span
class="w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin block"
></span>
</div>
{/if}
{:else}
<!-- Normal model search -->
<svg
class="w-5 h-5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search models..."
bind:value={searchQuery}
/>
<!-- Cluster memory -->
<span
class="text-xs font-mono flex-shrink-0"
title="Cluster memory usage"
><span class="text-exo-yellow">{Math.round(usedMemoryGB)}GB</span
><span class="text-white/40">/{Math.round(totalMemoryGB)}GB</span
></span
>
<!-- Filter button -->
<div class="relative filter-toggle">
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors {hasActiveFilters
? 'text-exo-yellow'
: 'text-white/50'}"
onclick={() => (showFilters = !showFilters)}
title="Filter by capability or size"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path d="M10 18h4v-2h-4v2zM3 6v2h18V6H3zm3 7h12v-2H6v2z" />
</svg>
</button>
{#if showFilters}
<ModelFilterPopover
{filters}
onChange={handleFiltersChange}
onClear={clearFilters}
onClose={() => (showFilters = false)}
/>
{/if}
</div>
{/if}
<!-- Close button -->
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors text-white/50 hover:text-white/70"
onclick={onClose}
title="Close model picker"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<!-- Body -->
<div class="flex flex-1 overflow-hidden">
<!-- Family sidebar -->
<FamilySidebar
families={uniqueFamilies}
{selectedFamily}
{hasFavorites}
onSelect={(family) => (selectedFamily = family)}
/>
<!-- Model list -->
<div class="flex-1 overflow-y-auto flex flex-col">
{#if selectedFamily === "huggingface"}
<!-- HuggingFace Hub view -->
<div class="flex-1 flex flex-col min-h-0">
<!-- Section header -->
<div
class="sticky top-0 z-10 px-3 py-2 bg-exo-dark-gray/95 border-b border-exo-yellow/10"
>
<span class="text-xs font-mono text-white/40">
{#if hfSearchQuery.length >= 2}
Search results for "{hfSearchQuery}"
{:else}
Trending on mlx-community
{/if}
</span>
</div>
<!-- Results list -->
<div class="flex-1 overflow-y-auto">
{#if hfIsLoadingTrending && hfTrendingModels.length === 0}
<div
class="flex items-center justify-center py-12 text-white/40"
>
<span
class="w-5 h-5 border-2 border-orange-400 border-t-transparent rounded-full animate-spin mr-2"
></span>
<span class="font-mono text-sm"
>Loading trending models...</span
>
</div>
{:else if hfDisplayModels.length === 0}
<div
class="flex flex-col items-center justify-center py-12 text-white/40"
>
<svg
class="w-10 h-10 mb-2"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 13.5c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm4 0c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm2-4.5H8c0-2.21 1.79-4 4-4s4 1.79 4 4z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hfSearchQuery}
<p class="font-mono text-xs mt-1">
Try a different search term
</p>
{/if}
</div>
{:else}
{#each hfDisplayModels as model}
<HuggingFaceResultItem
{model}
isAdded={existingModelIds.has(model.id)}
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
downloadedOnNodes={downloadsData
? getNodesWithModelDownloaded(
downloadsData,
model.id,
).map(getNodeName)
: []}
/>
{/each}
{/if}
</div>
<!-- Manual input footer -->
<div
class="sticky bottom-0 border-t border-exo-yellow/10 bg-exo-dark-gray p-3"
>
{#if addModelError}
<div
class="bg-red-500/10 border border-red-500/30 rounded px-3 py-2 mb-2"
>
<p class="text-red-400 text-xs font-mono break-words">
{addModelError}
</p>
</div>
{/if}
<div class="flex gap-2">
<input
type="text"
class="flex-1 bg-exo-black/60 border border-exo-yellow/30 rounded px-3 py-1.5 text-xs font-mono text-white placeholder-white/30 focus:outline-none focus:border-exo-yellow/50"
placeholder="Or paste model ID directly..."
bind:value={manualModelId}
onkeydown={(e) => {
if (e.key === "Enter") handleAddManualModel();
}}
/>
<button
type="button"
onclick={handleAddManualModel}
disabled={!manualModelId.trim() || addingModelId !== null}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded disabled:opacity-50 disabled:cursor-not-allowed"
>
Add
</button>
</div>
</div>
</div>
{:else if filteredGroups.length === 0}
<div
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
>
<svg class="w-12 h-12 mb-3" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hasActiveFilters || searchQuery}
<button
type="button"
class="mt-2 text-xs text-exo-yellow hover:underline"
onclick={() => {
searchQuery = "";
clearFilters();
}}
>
Clear filters
</button>
{/if}
</div>
{:else}
{#each filteredGroups as group}
<ModelPickerGroup
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatusMap={getVariantDownloadMap(group)}
/>
{/each}
{/if}
</div>
</div>
<!-- Footer with active filters indicator -->
{#if hasActiveFilters}
<div
class="flex items-center gap-2 px-3 py-2 border-t border-exo-yellow/10 bg-exo-medium-gray/20 text-xs font-mono text-white/50"
>
<span>Filters:</span>
{#each filters.capabilities as cap}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded"
>{cap}</span
>
{/each}
{#if filters.downloadedOnly}
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
>Downloaded</span
>
{/if}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
</span>
{/if}
<button
type="button"
class="ml-auto text-white/40 hover:text-white/60"
onclick={clearFilters}
>
Clear all
</button>
</div>
{/if}
</div>
<!-- Info modal -->
{#if infoGroup}
<div
class="fixed inset-0 z-[60] bg-black/60"
transition:fade={{ duration: 150 }}
onclick={() => (infoGroup = null)}
role="presentation"
></div>
<div
class="fixed z-[60] top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(80vw,400px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl p-4"
transition:fly={{ y: 10, duration: 200, easing: cubicOut }}
role="dialog"
aria-modal="true"
>
<div class="flex items-start justify-between mb-3">
<h3 class="font-mono text-lg text-white">{infoGroup.name}</h3>
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors text-white/50"
onclick={() => (infoGroup = null)}
title="Close model details"
aria-label="Close info dialog"
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<div class="space-y-2 text-xs font-mono">
<div class="flex items-center gap-2">
<span class="text-white/40">Family:</span>
<span class="text-white/70">{infoGroup.family || "Unknown"}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Capabilities:</span>
<span class="text-white/70">{infoGroup.capabilities.join(", ")}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Variants:</span>
<span class="text-white/70">{infoGroup.variants.length}</span>
</div>
{#if infoGroup.variants.length > 0}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<span class="text-white/40">Available quantizations:</span>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoGroup.variants as variant}
<span
class="px-1.5 py-0.5 bg-white/10 text-white/60 rounded text-[10px]"
>
{variant.quantization || "default"} ({Math.round(
(variant.storage_size_megabytes || 0) / 1024,
)}GB)
</span>
{/each}
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}
{/if}

View File

@@ -1,236 +0,0 @@
<script lang="ts">
import type { TokenData } from "$lib/stores/app.svelte";
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let {
tokens,
class: className = "",
isGenerating = false,
onRegenerateFrom,
}: Props = $props();
// Tooltip state - track both token data and index
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? {
token: tokens[hoveredTokenIndex],
index: hoveredTokenIndex,
...hoveredPosition,
}
: null,
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
return "bg-red-500/20 text-red-200/90"; // Draws attention
}
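// Worked example, assuming the thresholds above:
//   getConfidenceClass(0.92) => "text-inherit"                    (blends in)
//   getConfidenceClass(0.65) => "bg-gray-500/10 text-inherit"     (slight hint)
//   getConfidenceClass(0.30) => "bg-amber-500/15 text-amber-200/90"
//   getConfidenceClass(0.05) => "bg-red-500/20 text-red-200/90"   (draws attention)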
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return "border-transparent"; // No border for expected
if (probability > 0.5) return "border-gray-500/20";
if (probability > 0.2) return "border-amber-500/30";
return "border-red-500/40";
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(
event: MouseEvent,
token: TokenData,
index: number,
) {
clearHideTimeout();
const rects = (event.target as HTMLElement).getClientRects();
let rect = rects[0];
for (let j = 0; j < rects.length; j++) {
if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
rect = rects[j];
break;
}
}
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10,
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 200;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + "%";
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return "text-gray-300";
if (probability > 0.5) return "text-gray-400";
if (probability > 0.2) return "text-amber-400";
return "text-red-400";
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
tokenData.probability,
)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}>{tokenData.token}</span
>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50 pb-2"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div
class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
>
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1"
>"{hoveredToken.token.token}"</span
>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
>{formatProbability(hoveredToken.token.probability)}</span
>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono"
>{formatLogprob(hoveredToken.token.logprob)}</span
>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24"
>"{alt.token}"</span
>
<span class="text-gray-400 ml-2"
>{formatProbability(altProb)}</span
>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
/>
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -6,9 +6,3 @@ export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as ImageParamsPanel } from "./ImageParamsPanel.svelte";
export { default as FamilyLogos } from "./FamilyLogos.svelte";
export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

View File

@@ -242,19 +242,6 @@ export interface MessageAttachment {
mimeType?: string;
}
export interface TopLogprob {
token: string;
logprob: number;
bytes: number[] | null;
}
export interface TokenData {
token: string;
logprob: number;
probability: number;
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -266,7 +253,6 @@ export interface Message {
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
tokens?: TokenData[];
}
export interface Conversation {
@@ -554,18 +540,7 @@ class AppStore {
*/
private saveConversationsToStorage() {
try {
// Strip tokens from messages before saving to avoid bloating localStorage
const stripped = this.conversations.map((conv) => ({
...conv,
messages: conv.messages.map((msg) => {
if (msg.tokens) {
const { tokens: _, ...rest } = msg;
return rest;
}
return msg;
}),
}));
localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
} catch (error) {
console.error("Failed to save conversations:", error);
}
@@ -1470,213 +1445,6 @@ class AppStore {
}
}
/**
* Regenerate response from a specific token index.
* Truncates the assistant message at the given token and re-generates from there.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
if (this.isLoading) return;
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
const msgIndex = this.messages.findIndex((m) => m.id === messageId);
if (msgIndex === -1) return;
const msg = this.messages[msgIndex];
if (
msg.role !== "assistant" ||
!msg.tokens ||
tokenIndex >= msg.tokens.length
)
return;
// Keep tokens up to (not including) the specified index
const tokensToKeep = msg.tokens.slice(0, tokenIndex);
const prefixText = tokensToKeep.map((t) => t.token).join("");
// Remove all messages after this assistant message
this.messages = this.messages.slice(0, msgIndex + 1);
// Update the message to show the prefix
this.messages[msgIndex].content = prefixText;
this.messages[msgIndex].tokens = tokensToKeep;
this.updateActiveConversation();
// Set up for continuation - modify the existing message in place
this.isLoading = true;
this.currentResponse = prefixText;
this.ttftMs = null;
this.tps = null;
this.totalTokens = tokensToKeep.length;
try {
// Build messages for API - include the partial assistant message
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
const apiMessages = [
systemPrompt,
...this.messages.map((m) => {
let msgContent = m.content;
if (m.attachments) {
for (const attachment of m.attachments) {
if (attachment.type === "text" && attachment.content) {
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
}
}
}
return { role: m.role, content: msgContent };
}),
];
const modelToUse = this.getModelForRequest();
if (!modelToUse) {
throw new Error("No model available");
}
const requestStartTime = performance.now();
let firstTokenTime: number | null = null;
let tokenCount = tokensToKeep.length;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
let fullContent = prefixText;
const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
}
tokenCount += 1;
this.totalTokens = tokenCount;
if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
const elapsed = performance.now() - firstTokenTime;
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
}
// Update existing message in place
this.updateConversationMessage(
targetConversationId,
messageId,
(m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
},
);
// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error regenerating from token:", error);
if (this.conversationExists(targetConversationId)) {
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} finally {
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
/**
* Helper method to regenerate a chat completion response
*/
@@ -1745,8 +1513,6 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1761,49 +1527,16 @@ class AppStore {
}
let streamedContent = "";
const collectedTokens: TokenData[] = [];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
choices?: Array<{ delta?: { content?: string } }>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
const delta = parsed.choices?.[0]?.delta?.content;
if (delta) {
streamedContent += delta;
const { displayContent, thinkingContent } =
@@ -1821,7 +1554,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1840,7 +1572,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -2183,8 +1914,6 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -2201,48 +1930,14 @@ class AppStore {
let streamedContent = "";
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
choices?: Array<{ delta?: { content?: string } }>;
}
const collectedTokens: TokenData[] = [];
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
// Track first token for TTFT
if (firstTokenTime === null) {
@@ -2278,7 +1973,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -2303,7 +1997,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
// Store performance metrics on the message
if (this.ttftMs !== null) {
msg.ttftMs = this.ttftMs;
@@ -3000,8 +2693,6 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -1,97 +0,0 @@
/**
* FavoritesStore - Manages favorite models with localStorage persistence
*/
import { browser } from "$app/environment";
const FAVORITES_KEY = "exo-favorite-models";
class FavoritesStore {
favorites = $state<Set<string>>(new Set());
constructor() {
if (browser) {
this.loadFromStorage();
}
}
private loadFromStorage() {
try {
const stored = localStorage.getItem(FAVORITES_KEY);
if (stored) {
const parsed = JSON.parse(stored) as string[];
this.favorites = new Set(parsed);
}
} catch (error) {
console.error("Failed to load favorites:", error);
}
}
private saveToStorage() {
try {
const array = Array.from(this.favorites);
localStorage.setItem(FAVORITES_KEY, JSON.stringify(array));
} catch (error) {
console.error("Failed to save favorites:", error);
}
}
add(baseModelId: string) {
const next = new Set(this.favorites);
next.add(baseModelId);
this.favorites = next;
this.saveToStorage();
}
remove(baseModelId: string) {
const next = new Set(this.favorites);
next.delete(baseModelId);
this.favorites = next;
this.saveToStorage();
}
toggle(baseModelId: string) {
if (this.favorites.has(baseModelId)) {
this.remove(baseModelId);
} else {
this.add(baseModelId);
}
}
isFavorite(baseModelId: string): boolean {
return this.favorites.has(baseModelId);
}
getAll(): string[] {
return Array.from(this.favorites);
}
getSet(): Set<string> {
return new Set(this.favorites);
}
hasAny(): boolean {
return this.favorites.size > 0;
}
clearAll() {
this.favorites = new Set();
this.saveToStorage();
}
}
export const favoritesStore = new FavoritesStore();
export const favorites = () => favoritesStore.favorites;
export const hasFavorites = () => favoritesStore.hasAny();
export const isFavorite = (baseModelId: string) =>
favoritesStore.isFavorite(baseModelId);
export const toggleFavorite = (baseModelId: string) =>
favoritesStore.toggle(baseModelId);
export const addFavorite = (baseModelId: string) =>
favoritesStore.add(baseModelId);
export const removeFavorite = (baseModelId: string) =>
favoritesStore.remove(baseModelId);
export const getFavorites = () => favoritesStore.getAll();
export const getFavoritesSet = () => favoritesStore.getSet();
export const clearFavorites = () => favoritesStore.clearAll();
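// Usage sketch (hypothetical model ID):
//   toggleFavorite("exolabs/FLUX.1-dev"); // adds it on first call
//   isFavorite("exolabs/FLUX.1-dev");     // => true
//   toggleFavorite("exolabs/FLUX.1-dev"); // removes it again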

View File

@@ -1,152 +0,0 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
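// A minimal sketch of one node's entries, assuming the shape documented
// above (node key is hypothetical):
//
//   const downloadsData: Record<string, unknown[]> = {
//     "node-1": [
//       {
//         DownloadCompleted: {
//           shard_metadata: {
//             PipelineShardMetadata: {
//               model_card: { model_id: "exolabs/FLUX.1-dev-4bit" },
//             },
//           },
//         },
//       },
//     ],
//   };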
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
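// e.g. unwrapTagged({ DownloadCompleted: { /* payload */ } })
//   => ["DownloadCompleted", { /* payload */ }]; objects with zero or
//   multiple keys => null.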
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
/**
 * Find shard metadata for a model from any download entry across all nodes.
 * Returns the first DownloadCompleted match if one exists; otherwise falls
 * back to the first entry of any other status.
 */
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}
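// Example precedence (hypothetical node ID): if "node-1" has both a
// DownloadOngoing and a DownloadCompleted entry for the same model,
//   getModelDownloadStatus(downloadsData, "node-1", modelId)
// returns "DownloadCompleted"; with only in-progress entries it returns
// "DownloadOngoing", and null when no entry matches.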

View File

@@ -5,13 +5,7 @@
ChatMessages,
ChatSidebar,
ModelCard,
ModelPickerModal,
} from "$lib/components";
import {
favorites,
toggleFavorite,
getFavoritesSet,
} from "$lib/stores/favorites.svelte";
import {
hasStartedChat,
isTopologyMinimized,
@@ -106,11 +100,6 @@
storage_size_megabytes?: number;
tasks?: string[];
hugging_face_id?: string;
is_custom?: boolean;
family?: string;
quantization?: string;
base_model?: string;
capabilities?: string[];
}>
>([]);
@@ -222,11 +211,9 @@
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Model picker modal state
let isModelPickerOpen = $state(false);
// Favorites state (reactive)
const favoritesSet = $derived(getFavoritesSet());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state("");
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -543,47 +530,6 @@
}
}
async function addModelFromPicker(modelId: string) {
const response = await fetch("/models/add", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ model_id: modelId }),
});
if (!response.ok) {
let message = `Failed to add model (${response.status}: ${response.statusText})`;
try {
const err = await response.json();
if (err.detail) message = err.detail;
} catch {
// use default message
}
throw new Error(message);
}
await fetchModels();
}
async function deleteCustomModel(modelId: string) {
try {
const response = await fetch(
`/models/custom/${encodeURIComponent(modelId)}`,
{ method: "DELETE" },
);
if (response.ok) {
await fetchModels();
}
} catch (error) {
console.error("Failed to delete custom model:", error);
}
}
function handleModelPickerSelect(modelId: string) {
selectPreviewModel(modelId);
saveLaunchDefaults();
isModelPickerOpen = false;
}
async function launchInstance(
modelId: string,
specificPreview?: PlacementPreview | null,
@@ -2414,12 +2360,14 @@
>
</div>
<!-- Model Picker Button -->
<div class="flex-shrink-0 mb-3">
<!-- Model Dropdown (Custom) -->
<div class="flex-shrink-0 mb-3 relative">
<button
type="button"
onclick={() => (isModelPickerOpen = true)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 relative"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
{#if selectedModelId}
{@const foundModel = models.find(
@@ -2427,12 +2375,54 @@
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
<span
class="flex items-center gap-2 text-exo-light-gray truncate"
>
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2449,24 +2439,142 @@
{:else}
<span class="text-white/50"> SELECT MODEL </span>
{/if}
<div
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none"
>
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</button>
<div
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
{#if isModelDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-40 cursor-default"
onclick={() => (isModelDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel -->
<div
class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-50 max-h-64 overflow-y-auto"
>
<!-- Search within dropdown -->
<div
class="sticky top-0 bg-exo-dark-gray border-b border-exo-medium-gray/30 p-2"
>
<input
type="text"
placeholder="Search models..."
bind:value={modelDropdownSearch}
class="w-full bg-exo-dark-gray/60 border border-exo-medium-gray/30 rounded px-2 py-1.5 text-xs font-mono text-white/80 placeholder:text-white/40 focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<!-- Options -->
<div class="py-1">
{#each sortedModels().filter((m) => !modelDropdownSearch || (m.name || m.id)
.toLowerCase()
.includes(modelDropdownSearch.toLowerCase())) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
}
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
: 'text-white/30 cursor-default'}"
>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image generation model"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span
class="flex-shrink-0 text-xs {modelCanFit
? 'text-white/50'
: 'text-red-400/60'}"
>
{sizeGB >= 1
? sizeGB.toFixed(0)
: sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">
No models found
</div>
{/each}
</div>
</div>
{/if}
</div>
<!-- Configuration Options -->
@@ -3246,24 +3354,3 @@
{/if}
</main>
</div>
<ModelPickerModal
isOpen={isModelPickerOpen}
{models}
{selectedModelId}
favorites={favoritesSet}
existingModelIds={new Set(models.map((m) => m.id))}
canModelFit={(modelId) => {
const model = models.find((m) => m.id === modelId);
return model ? hasEnoughMemory(model) : false;
}}
onSelect={handleModelPickerSelect}
onClose={() => (isModelPickerOpen = false)}
onToggleFavorite={toggleFavorite}
onAddModel={addModelFromPicker}
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>

View File

@@ -118,10 +118,9 @@
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
inherit (self'.packages) metal-toolchain;
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
default = self'.packages.exo;
}
);

View File

@@ -10,7 +10,6 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -19,9 +18,6 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -62,7 +58,6 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.5",
"mflux==0.15.4",
"python-multipart>=0.0.21",
]

View File

@@ -69,8 +69,7 @@
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Kontext-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,36 +0,0 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,36 +0,0 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,36 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,36 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,36 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,36 +0,0 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405874409472

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 765577920512

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 122406567936

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 229780750336

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 198556925568

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 286737579648

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 396963397248

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 19327352832

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 22548578304

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 26843545600

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 34359738368

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = "4bit"
base_model = "Kimi K2"
capabilities = ["text"]
[storage_size]
in_bytes = 620622774272

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 706522120192

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 662498705408

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 1B"
capabilities = ["text"]
[storage_size]
in_bytes = 729808896

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 1863319552

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 3501195264

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 76799803392

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 4637851648

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 8954839040

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 16882073600

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 100086644736

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 342884352

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 698351616

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 141733920768

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 268435456000

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 17612931072

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 33279705088

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 289910292480

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 579820584960

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 45644286500

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 57657697020

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 68899327465

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 89357758772

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 157548627945

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 46976204800

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 47080074240

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 120B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 70652212224

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 20B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 12025908224

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "fp16"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 144383672320
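
Aside: several of the `in_bytes` values above are exact binary sizes, which makes quick sanity checks easy — GLM-4.7-Flash-8bit at 34359738368 bytes is exactly 32 GiB, and Qwen3-235B-A22B-Instruct-2507-8bit at 268435456000 bytes is exactly 250 GiB. A one-line conversion sketch:

    def in_gib(n_bytes: int) -> float:
        # 1 GiB = 1024**3 bytes
        return n_bytes / 1024**3

    assert in_gib(34359738368) == 32.0    # GLM-4.7-Flash-8bit
    assert in_gib(268435456000) == 250.0  # Qwen3-235B-A22B-Instruct-2507-8bit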

View File

@@ -1,5 +1,4 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
@@ -53,44 +52,18 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
_tg: TaskGroup = field(init=False)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
async def _check_internet_connection(self) -> None:
first_connection = True
while True:
await asyncio.sleep(10)
# Assume that internet connection is set to False on 443 errors.
if self.shard_downloader.internet_connection:
continue
self._test_internet_connection()
if first_connection and self.shard_downloader.internet_connection:
first_connection = False
self._tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
@@ -268,7 +241,7 @@ class DownloadCoordinator:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug(
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
@@ -301,10 +274,10 @@ class DownloadCoordinator:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug(
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(60)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"

View File

@@ -49,10 +49,6 @@ class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
class HuggingFaceRateLimitError(Exception):
"""429 Huggingface code"""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
@@ -158,76 +154,49 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
_fetched_file_lists_this_session: set[str] = set()
async def fetch_file_list_with_cache(
model_id: ModelId,
revision: str = "main",
recursive: bool = False,
skip_internet: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_key = f"{model_id.normalize()}--{revision}"
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
cache_file
):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
if skip_internet:
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id,
revision,
recursive=recursive,
on_connection_lost=on_connection_lost,
model_id, revision, recursive=recursive
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
# No cache available, propagate the error
raise
async def fetch_file_list_with_retry(
model_id: ModelId,
revision: str = "main",
path: str = "",
recursive: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 3
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(2.0**attempt)
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -247,11 +216,7 @@ async def _fetch_file_list(
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
)
elif response.status == 200:
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
@@ -284,7 +249,7 @@ def create_http_session(
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
@@ -359,9 +324,8 @@ async def download_file_with_retry(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> Path:
n_attempts = 3
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _download_file(
@@ -369,19 +333,14 @@ async def download_file_with_retry(
)
except HuggingFaceAuthenticationError:
raise
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
break
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -583,9 +542,7 @@ async def download_shard(
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
max_parallel_downloads: int = 8,
skip_download: bool = False,
skip_internet: bool = False,
allow_patterns: list[str] | None = None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.debug(f"Downloading {shard.model_card.model_id=}")
@@ -605,11 +562,7 @@ async def download_shard(
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id,
revision,
recursive=True,
skip_internet=skip_internet,
on_connection_lost=on_connection_lost,
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -719,7 +672,6 @@ async def download_shard(
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
)
if not skip_download:
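
Aside: the retry change above trades a small attempt budget with plain doubling (2.0**attempt, 3 attempts) for a large budget with a capped backoff (min(8, 0.1 * 2.0**attempt), 30 attempts), so delays ramp 0.1 s, 0.2 s, 0.4 s, ... and plateau at 8 s from attempt 7 onward. A sketch of that retry loop in isolation — the fetch callable is a stand-in, only the backoff schedule comes from the diff:

    import asyncio
    from typing import Awaitable, Callable, TypeVar

    T = TypeVar("T")

    async def retry_with_capped_backoff(
        fetch: Callable[[], Awaitable[T]],
        n_attempts: int = 30,
        base: float = 0.1,
        cap: float = 8.0,
    ) -> T:
        for attempt in range(n_attempts):
            try:
                return await fetch()
            except Exception:
                if attempt == n_attempts - 1:
                    raise
                # 0.1, 0.2, 0.4, ... capped at 8 seconds (cap reached at attempt 7).
                await asyncio.sleep(min(cap, base * 2.0**attempt))
        raise RuntimeError("unreachable")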

View File

@@ -1,5 +1,4 @@
import asyncio
from asyncio import create_task
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -8,7 +7,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -50,10 +49,6 @@ class SingletonShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -90,10 +85,6 @@ class CachedShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -151,8 +142,6 @@ class ResumableShardDownloader(ShardDownloader):
self.on_progress_wrapper,
max_parallel_downloads=self.max_parallel_downloads,
allow_patterns=allow_patterns,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return target_dir
@@ -165,24 +154,13 @@ class ResumableShardDownloader(ShardDownloader):
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)
return await download_shard(
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
shard, self.on_progress_wrapper, skip_download=True
)
semaphore = asyncio.Semaphore(self.max_parallel_downloads)
async def download_with_semaphore(
model_card: ModelCard,
) -> tuple[Path, RepoDownloadProgress]:
async with semaphore:
return await _status_for_model(model_card.model_id)
# Kick off download status coroutines concurrently
tasks = [
create_task(download_with_semaphore(model_card))
for model_card in await get_model_cards()
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
]
for task in asyncio.as_completed(tasks):
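
Aside: the semaphore wrapper removed here is the standard way to bound fan-out when kicking off one task per model card. A minimal sketch of the shape — the names are illustrative; only the semaphore-plus-as_completed pattern comes from the code above:

    import asyncio

    async def bounded_gather(coros, max_parallel: int = 8):
        semaphore = asyncio.Semaphore(max_parallel)

        async def run_with_semaphore(coro):
            # At most max_parallel coroutines run concurrently.
            async with semaphore:
                return await coro

        tasks = [asyncio.create_task(run_with_semaphore(c)) for c in coros]
        results = []
        for task in asyncio.as_completed(tasks):
            results.append(await task)
        return results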

View File

@@ -16,11 +16,6 @@ from exo.shared.types.worker.shards import (
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
internet_connection: bool = False
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
@abstractmethod
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False

View File

@@ -27,6 +27,7 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -135,6 +136,7 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -146,8 +148,6 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -14,8 +14,6 @@ from exo.shared.types.api import (
ErrorInfo,
ErrorResponse,
FinishReason,
Logprobs,
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
)
@@ -68,9 +66,7 @@ def chat_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
@@ -83,8 +79,6 @@ def chat_request_to_text_generation(
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
logprobs=request.logprobs or False,
top_logprobs=request.top_logprobs,
)
@@ -92,19 +86,6 @@ def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -113,7 +94,6 @@ def chunk_to_response(
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
@@ -180,7 +160,6 @@ async def collect_chat_response(
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_content: list[LogprobsContentItem] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
@@ -195,14 +174,6 @@ async def collect_chat_response(
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
@@ -235,9 +206,6 @@ async def collect_chat_response(
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
),
logprobs=Logprobs(content=logprobs_content)
if logprobs_content
else None,
finish_reason=finish_reason,
)
],
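
Aside: the logprobs plumbing removed in this hunk turns each emitted token into one content item carrying token, logprob, and top_logprobs, accumulated into a Logprobs envelope. A sketch of that accumulation with plain dicts — Logprobs/LogprobsContentItem are the removed Pydantic models; dicts stand in for them here:

    def build_logprobs_content(chunks: list[dict]) -> list[dict]:
        # Mirror of the removed loop: one item per token chunk that has a logprob.
        return [
            {
                "token": chunk["text"],
                "logprob": chunk["logprob"],
                "top_logprobs": chunk.get("top_logprobs") or [],
            }
            for chunk in chunks
            if chunk.get("logprob") is not None
        ]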

View File

@@ -141,9 +141,7 @@ def claude_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,

View File

@@ -43,10 +43,10 @@ def _extract_content(content: str | list[ResponseContentPart]) -> str:
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: list[InputMessage]
input_value: str | list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = [InputMessage(role="user", content=request.input)]
input_value = request.input
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
@@ -95,11 +95,7 @@ def responses_request_to_text_generation(
}
)
input_value = (
input_messages
if input_messages
else [InputMessage(role="user", content="")]
)
input_value = input_messages if input_messages else ""
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(

View File

@@ -1,7 +1,6 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from datetime import datetime, timezone
@@ -41,7 +40,6 @@ from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
@@ -49,15 +47,12 @@ from exo.shared.constants import (
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
delete_custom_card,
get_model_cards,
is_custom_card,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
AddCustomModelParams,
AdvancedImageParams,
BenchChatCompletionRequest,
BenchChatCompletionResponse,
@@ -75,7 +70,6 @@ from exo.shared.types.api import (
ErrorResponse,
FinishReason,
GenerationStats,
HuggingFaceSearchResult,
ImageData,
ImageEditsTaskParams,
ImageGenerationResponse,
@@ -144,6 +138,7 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@@ -151,13 +146,16 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
"""Ensure advanced params has a seed set for distributed consistency."""
if params is None:
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
if params.seed is None:
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
return params
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
@@ -206,7 +204,7 @@ class API:
self.app.mount(
"/",
StaticFiles(
directory=DASHBOARD_DIR,
directory=find_dashboard(),
html=True,
),
name="dashboard",
@@ -271,9 +269,6 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
self.app.delete("/models/custom/{model_id:path}")(self.delete_custom_model)
self.app.get("/models/search")(self.search_models)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
@@ -386,7 +381,10 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
model_card = await ModelCard.load(model_id)
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
for sharding in (Sharding.Pipeline, Sharding.Tensor):
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
@@ -401,93 +399,96 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
)
return PlacementPreviewResponse(previews=previews)
@@ -627,11 +628,6 @@ class API:
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await collect_chat_response(
@@ -656,21 +652,23 @@ class API:
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved = model_card.model_id
if not any(
instance.shard_assignments.model_id == model_id
instance.shard_assignments.model_id == resolved
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(model_id)
await self._trigger_notify_user_to_download_model(resolved)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {model_id}",
detail=f"No instance found for model {resolved}",
)
return model_id
return resolved
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
@@ -724,9 +722,6 @@ class API:
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
task_params=payload,
@@ -975,9 +970,6 @@ class API:
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
task_params=payload,
@@ -1009,7 +1001,6 @@ class API:
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
advanced_params = _ensure_seed(advanced_params)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1188,11 +1179,6 @@ class API:
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await collect_claude_response(
@@ -1220,11 +1206,6 @@ class API:
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await collect_responses_response(
@@ -1255,105 +1236,35 @@ class API:
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
family=card.family,
quantization=card.quantization,
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
for card in MODEL_CARDS.values()
]
)
async def add_custom_model(self, payload: AddCustomModelParams) -> ModelListModel:
"""Fetch a model from HuggingFace and save as a custom model card."""
try:
card = await ModelCard.fetch_from_hf(payload.model_id)
except Exception as exc:
raise HTTPException(
status_code=400, detail=f"Failed to fetch model: {exc}"
) from exc
return ModelListModel(
id=card.model_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=True,
)
async def delete_custom_model(self, model_id: ModelId) -> JSONResponse:
"""Delete a user-added custom model card."""
deleted = await delete_custom_card(model_id)
if not deleted:
raise HTTPException(status_code=404, detail="Custom model card not found")
return JSONResponse(
{"message": "Model card deleted", "model_id": str(model_id)}
)
async def search_models(
self, query: str = "", limit: int = 20
) -> list[HuggingFaceSearchResult]:
"""Search HuggingFace Hub for mlx-community models."""
from huggingface_hub import list_models
results = list_models(
search=query or None,
author="mlx-community",
sort="downloads",
limit=limit,
)
return [
HuggingFaceSearchResult(
id=m.id,
author=m.author or "",
downloads=m.downloads or 0,
likes=m.likes or 0,
last_modified=str(m.last_modified or ""),
tags=list(m.tags or []),
)
for m in results
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = [f"0.0.0.0:{self.port}"]
cfg.bind = f"0.0.0.0:{self.port}"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
with anyio.CancelScope(shield=True):
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=ev.wait,
shutdown_trigger=lambda: anyio.sleep_forever(),
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
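
Aside: the memory_delta_by_node computation retained above splits a model's total bytes evenly across placement nodes, giving the first remainder nodes (in sorted order) one extra byte so the shares sum exactly. A self-contained sketch of that scheme — function name illustrative, and it sorts plain strings where the original sorts node IDs with key=str:

    def split_bytes_evenly(total_bytes: int, node_ids: list[str]) -> dict[str, int]:
        per_node = total_bytes // len(node_ids)
        remainder = total_bytes % len(node_ids)
        return {
            node_id: per_node + (1 if index < remainder else 0)
            for index, node_id in enumerate(sorted(node_ids))
        }

    # 10 bytes over 3 nodes: shares 4 + 3 + 3 sum back to 10 exactly.
    assert split_bytes_evenly(10, ["a", "b", "c"]) == {"a": 4, "b": 3, "c": 3}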

View File

@@ -96,18 +96,16 @@ class Master:
async def run(self):
logger.info("Starting Master")
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")

View File

@@ -10,7 +10,6 @@ from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
Sharding,
ShardMetadata,
@@ -75,43 +74,40 @@ def allocate_layers_proportionally(
return result
def _validate_cycle(cycle: Cycle) -> None:
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
def _compute_total_memory(
node_ids: list[NodeId],
node_memory: Mapping[NodeId, MemoryUsage],
) -> Memory:
total_memory = sum(
(node_memory[node_id].ram_available for node_id in node_ids),
cycle_memory = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if total_memory.in_bytes == 0:
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
return total_memory
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
def _allocate_and_validate_layers(
node_ids: list[NodeId],
node_memory: Mapping[NodeId, MemoryUsage],
total_memory: Memory,
model_card: ModelCard,
) -> list[int]:
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
total_storage_bytes = model_card.storage_size.in_bytes
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
@@ -120,126 +116,33 @@ def _allocate_and_validate_layers(
f"but only has {available_memory / (1024**3):.2f} GB available"
)
return layer_allocations
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for pipeline parallel execution."""
world_size = len(cycle)
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
if use_cfg_parallel:
return _get_shard_assignments_for_cfg_parallel(model_card, cycle, node_memory)
else:
return _get_shard_assignments_for_pure_pipeline(model_card, cycle, node_memory)
def _get_shard_assignments_for_cfg_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for CFG parallel execution.
CFG parallel runs two independent pipelines. Group 0 processes the positive
prompt, group 1 processes the negative prompt. The ring topology places
group 1's ranks in reverse order so both "last stages" are neighbors for
efficient CFG exchange.
"""
_validate_cycle(cycle)
world_size = len(cycle)
cfg_world_size = 2
pipeline_world_size = world_size // cfg_world_size
# Allocate layers for one pipeline group (both groups run the same layers)
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
pipeline_memory = _compute_total_memory(pipeline_node_ids, node_memory)
layer_allocations = _allocate_and_validate_layers(
pipeline_node_ids, node_memory, pipeline_memory, model_card
)
# Ring topology: group 0 ascending [0,1,2,...], group 1 descending [...,2,1,0]
# This places both last stages as neighbors for CFG exchange.
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
(1, r) for r in reversed(range(pipeline_world_size))
]
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for device_rank, node_id in enumerate(cycle.node_ids):
cfg_rank, pipeline_rank = position_to_cfg_pipeline[device_rank]
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
shard = CfgShardMetadata(
model_card=model_card,
device_rank=device_rank,
world_size=world_size,
start_layer=layers_before,
end_layer=layers_before + node_layers,
n_layers=model_card.n_layers,
cfg_rank=cfg_rank,
cfg_world_size=cfg_world_size,
pipeline_rank=pipeline_rank,
pipeline_world_size=pipeline_world_size,
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
return ShardAssignments(
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def _get_shard_assignments_for_pure_pipeline(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for pure pipeline execution."""
_validate_cycle(cycle)
total_memory = _compute_total_memory(cycle.node_ids, node_memory)
layer_allocations = _allocate_and_validate_layers(
cycle.node_ids, node_memory, total_memory, model_card
)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for pipeline_rank, node_id in enumerate(cycle.node_ids):
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=pipeline_rank,
world_size=len(cycle),
start_layer=layers_before,
end_layer=layers_before + node_layers,
n_layers=model_card.n_layers,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
)
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
return ShardAssignments(
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
return shard_assignments
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,
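
Aside: the removed CFG-parallel helper orders the ring so group 0's pipeline ranks ascend and group 1's descend, which places both last pipeline stages on adjacent ring positions for the CFG exchange. A sketch of that mapping, taken from the removed position_to_cfg_pipeline construction, with a worked example:

    def cfg_ring_positions(world_size: int) -> list[tuple[int, int]]:
        # (cfg_rank, pipeline_rank) per ring position; group 1 is reversed so
        # the two last pipeline stages end up next to each other.
        pipeline_world_size = world_size // 2
        return [(0, r) for r in range(pipeline_world_size)] + [
            (1, r) for r in reversed(range(pipeline_world_size))
        ]

    # world_size=4: the two rank-1 (last) stages occupy adjacent positions 1 and 2.
    assert cfg_ring_positions(4) == [(0, 0), (0, 1), (1, 1), (1, 0)]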

Some files were not shown because too many files have changed in this diff.