mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-17 18:41:49 -05:00
Compare commits
1 Commits
alexcheema
...
model-card
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2f9f50f7e |
9
dashboard/package-lock.json
generated
9
dashboard/package-lock.json
generated
@@ -863,6 +863,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -902,6 +903,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1518,6 +1520,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1527,6 +1530,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1939,6 +1943,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2646,6 +2651,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2833,6 +2839,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2977,6 +2984,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -2998,6 +3006,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -60,39 +60,12 @@
|
||||
return models;
|
||||
});
|
||||
|
||||
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
|
||||
let previousModelIds: Set<string> = new Set();
|
||||
|
||||
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
|
||||
// Auto-select the first available model if none is selected
|
||||
$effect(() => {
|
||||
const models = availableModels();
|
||||
const currentModelIds = new Set(models.map(m => m.id));
|
||||
|
||||
if (models.length > 0) {
|
||||
// Find newly added models (in current but not in previous)
|
||||
const newModels = models.filter(m => !previousModelIds.has(m.id));
|
||||
|
||||
// If no model selected, select the first available
|
||||
if (!currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If current model is stale (no longer has a running instance), reset to first available
|
||||
else if (!models.some(m => m.id === currentModel)) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If a new model was just added, select it
|
||||
else if (newModels.length > 0 && previousModelIds.size > 0) {
|
||||
setSelectedChatModel(newModels[0].id);
|
||||
}
|
||||
} else {
|
||||
// No instances running - clear the selected model
|
||||
if (currentModel) {
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
if (models.length > 0 && !currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
|
||||
// Update previous model IDs for next comparison
|
||||
previousModelIds = currentModelIds;
|
||||
});
|
||||
|
||||
function getInstanceModelId(instanceWrapped: unknown): string {
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
isLoading,
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
regenerateLastResponse,
|
||||
regenerateFromToken
|
||||
regenerateLastResponse
|
||||
} from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
import TokenHeatmap from './TokenHeatmap.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -97,23 +95,6 @@
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
// Uncertainty view state - tracks which messages show token heatmap
|
||||
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function toggleUncertaintyView(messageId: string) {
|
||||
const newSet = new Set(uncertaintyViewMessageIds);
|
||||
if (newSet.has(messageId)) {
|
||||
newSet.delete(messageId);
|
||||
} else {
|
||||
newSet.add(messageId);
|
||||
}
|
||||
uncertaintyViewMessageIds = newSet;
|
||||
}
|
||||
|
||||
function isUncertaintyViewEnabled(messageId: string): boolean {
|
||||
return uncertaintyViewMessageIds.has(messageId);
|
||||
}
|
||||
|
||||
function formatTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toLocaleTimeString('en-US', {
|
||||
hour12: false,
|
||||
@@ -385,17 +366,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-xs text-foreground">
|
||||
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
|
||||
<!-- Uncertainty heatmap view -->
|
||||
<TokenHeatmap
|
||||
tokens={message.tokens}
|
||||
isGenerating={loading}
|
||||
onRegenerateFrom={(tokenIndex) => regenerateFromToken(message.id, tokenIndex)}
|
||||
/>
|
||||
{:else}
|
||||
<!-- Normal markdown view -->
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{/if}
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
@@ -448,19 +419,6 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
|
||||
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
|
||||
<button
|
||||
onclick={() => toggleUncertaintyView(message.id)}
|
||||
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { TokenData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
tokens: TokenData[];
|
||||
class?: string;
|
||||
isGenerating?: boolean;
|
||||
onRegenerateFrom?: (tokenIndex: number) => void;
|
||||
}
|
||||
|
||||
let { tokens, class: className = '', isGenerating = false, onRegenerateFrom }: Props = $props();
|
||||
|
||||
// Tooltip state - track both token data and index
|
||||
let hoveredTokenIndex = $state<number | null>(null);
|
||||
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
|
||||
let isTooltipHovered = $state(false);
|
||||
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
// Derive the hovered token from the index (stable across re-renders)
|
||||
const hoveredToken = $derived(
|
||||
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
|
||||
? { token: tokens[hoveredTokenIndex], index: hoveredTokenIndex, ...hoveredPosition }
|
||||
: null
|
||||
);
|
||||
|
||||
/**
|
||||
* Get confidence styling based on probability.
|
||||
* Following Apple design principles: high confidence tokens blend in,
|
||||
* only uncertainty draws attention.
|
||||
*/
|
||||
function getConfidenceClass(probability: number): string {
|
||||
if (probability > 0.8) return 'text-inherit'; // Expected tokens - blend in
|
||||
if (probability > 0.5) return 'bg-gray-500/10 text-inherit'; // Slight hint
|
||||
if (probability > 0.2) return 'bg-amber-500/15 text-amber-200/90'; // Subtle warmth
|
||||
return 'bg-red-500/20 text-red-200/90'; // Draws attention
|
||||
}
|
||||
|
||||
/**
|
||||
* Get border/underline styling for uncertain tokens
|
||||
*/
|
||||
function getBorderClass(probability: number): string {
|
||||
if (probability > 0.8) return 'border-transparent'; // No border for expected
|
||||
if (probability > 0.5) return 'border-gray-500/20';
|
||||
if (probability > 0.2) return 'border-amber-500/30';
|
||||
return 'border-red-500/40';
|
||||
}
|
||||
|
||||
function clearHideTimeout() {
|
||||
if (hideTimeoutId) {
|
||||
clearTimeout(hideTimeoutId);
|
||||
hideTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseEnter(event: MouseEvent, token: TokenData, index: number) {
|
||||
clearHideTimeout();
|
||||
const rect = (event.target as HTMLElement).getBoundingClientRect();
|
||||
hoveredTokenIndex = index;
|
||||
hoveredPosition = {
|
||||
x: rect.left + rect.width / 2,
|
||||
y: rect.top - 10
|
||||
};
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
clearHideTimeout();
|
||||
// Use longer delay during generation to account for re-renders
|
||||
const delay = isGenerating ? 300 : 100;
|
||||
hideTimeoutId = setTimeout(() => {
|
||||
if (!isTooltipHovered) {
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
}, delay);
|
||||
}
|
||||
|
||||
function handleTooltipEnter() {
|
||||
clearHideTimeout();
|
||||
isTooltipHovered = true;
|
||||
}
|
||||
|
||||
function handleTooltipLeave() {
|
||||
isTooltipHovered = false;
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
|
||||
function handleRegenerate() {
|
||||
if (hoveredToken && onRegenerateFrom) {
|
||||
const indexToRegenerate = hoveredToken.index;
|
||||
// Clear hover state immediately
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
isTooltipHovered = false;
|
||||
// Call regenerate
|
||||
onRegenerateFrom(indexToRegenerate);
|
||||
}
|
||||
}
|
||||
|
||||
function formatProbability(prob: number): string {
|
||||
return (prob * 100).toFixed(1) + '%';
|
||||
}
|
||||
|
||||
function formatLogprob(logprob: number): string {
|
||||
return logprob.toFixed(3);
|
||||
}
|
||||
|
||||
function getProbabilityColor(probability: number): string {
|
||||
if (probability > 0.8) return 'text-gray-300';
|
||||
if (probability > 0.5) return 'text-gray-400';
|
||||
if (probability > 0.2) return 'text-amber-400';
|
||||
return 'text-red-400';
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="token-heatmap leading-relaxed {className}">
|
||||
{#each tokens as tokenData, i (i)}
|
||||
<span
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
|
||||
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
|
||||
onmouseleave={handleMouseLeave}
|
||||
>{tokenData.token}</span>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Tooltip -->
|
||||
{#if hoveredToken}
|
||||
<div
|
||||
class="fixed z-50"
|
||||
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
|
||||
onmouseenter={handleTooltipEnter}
|
||||
onmouseleave={handleTooltipLeave}
|
||||
>
|
||||
<div class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48">
|
||||
<!-- Token info -->
|
||||
<div class="mb-2">
|
||||
<span class="text-gray-500 text-xs">Token:</span>
|
||||
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
|
||||
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2">{formatProbability(hoveredToken.token.probability)}</span>
|
||||
</div>
|
||||
|
||||
<div class="text-gray-400 text-xs mb-1">
|
||||
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
|
||||
</div>
|
||||
|
||||
<!-- Top alternatives -->
|
||||
{#if hoveredToken.token.topLogprobs.length > 0}
|
||||
<div class="border-t border-gray-700/50 mt-2 pt-2">
|
||||
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
|
||||
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
|
||||
{@const altProb = Math.exp(alt.logprob)}
|
||||
<div class="flex justify-between items-center text-xs py-0.5">
|
||||
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
|
||||
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Regenerate button -->
|
||||
{#if onRegenerateFrom}
|
||||
<button
|
||||
onclick={handleRegenerate}
|
||||
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
|
||||
>
|
||||
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
||||
</svg>
|
||||
Regenerate from here
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Arrow -->
|
||||
<div class="absolute left-1/2 -translate-x-1/2 top-full">
|
||||
<div class="border-8 border-transparent border-t-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.token-heatmap {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.token-span {
|
||||
margin: 0;
|
||||
border-width: 1px;
|
||||
}
|
||||
</style>
|
||||
@@ -182,20 +182,6 @@ export interface MessageAttachment {
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
// Token-level data for uncertainty visualization
|
||||
export interface TopLogprob {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}
|
||||
|
||||
export interface TokenData {
|
||||
token: string;
|
||||
logprob: number;
|
||||
probability: number; // exp(logprob)
|
||||
topLogprobs: TopLogprob[];
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
@@ -205,7 +191,6 @@ export interface Message {
|
||||
attachments?: MessageAttachment[];
|
||||
ttftMs?: number; // Time to first token in ms (for assistant messages)
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
tokens?: TokenData[]; // Token-level data for uncertainty visualization
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -383,21 +368,6 @@ class AppStore {
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private currentRequestController: AbortController | null = null;
|
||||
|
||||
/**
|
||||
* Abort any in-flight generation request
|
||||
*/
|
||||
abortCurrentRequest(): boolean {
|
||||
if (this.currentRequestController) {
|
||||
this.currentRequestController.abort();
|
||||
this.currentRequestController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -1076,10 +1046,6 @@ class AppStore {
|
||||
// Remove any messages after the user message
|
||||
this.messages = this.messages.slice(0, lastUserIndex + 1);
|
||||
|
||||
// Create abort controller for this request
|
||||
const controller = new AbortController();
|
||||
this.currentRequestController = controller;
|
||||
|
||||
// Resend the message to get a new response
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
@@ -1141,10 +1107,7 @@ class AppStore {
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1177,7 +1140,6 @@ class AppStore {
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = "";
|
||||
let partialLine = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -1196,29 +1158,6 @@ class AppStore {
|
||||
const json = JSON.parse(trimmed.slice(6));
|
||||
const delta = json.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = json.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || delta,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
@@ -1231,7 +1170,6 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
this.messages[idx].tokens = [...collectedTokens];
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
@@ -1249,16 +1187,9 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[idx].tokens = collectedTokens;
|
||||
}
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
// Don't show error for aborted requests (user cancelled)
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content =
|
||||
@@ -1266,10 +1197,6 @@ class AppStore {
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
// Clean up controller if this is still the active request
|
||||
if (this.currentRequestController === controller) {
|
||||
this.currentRequestController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.updateActiveConversation();
|
||||
@@ -1291,210 +1218,6 @@ class AppStore {
|
||||
this.tps = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Regenerate from a specific token in an assistant message.
|
||||
* Keeps content up to and including the specified token, then continues generation.
|
||||
* If a generation is already in progress, it will be aborted first.
|
||||
*/
|
||||
async regenerateFromToken(
|
||||
messageId: string,
|
||||
tokenIndex: number,
|
||||
): Promise<void> {
|
||||
// Abort any in-flight request first
|
||||
this.abortCurrentRequest();
|
||||
|
||||
const messageIdx = this.messages.findIndex((m) => m.id === messageId);
|
||||
if (messageIdx === -1) return;
|
||||
|
||||
const message = this.messages[messageIdx];
|
||||
if (message.role !== "assistant" || !message.tokens) return;
|
||||
|
||||
// Get tokens up to and including the specified index
|
||||
const keptTokens = message.tokens.slice(0, tokenIndex + 1);
|
||||
const prefixText = keptTokens.map((t) => t.token).join("");
|
||||
|
||||
// Update the message with just the prefix
|
||||
this.messages[messageIdx].content = prefixText;
|
||||
this.messages[messageIdx].tokens = keptTokens;
|
||||
this.messages[messageIdx].thinking = undefined;
|
||||
this.persistActiveConversation();
|
||||
|
||||
// Start loading
|
||||
this.isLoading = true;
|
||||
this.currentResponse = prefixText;
|
||||
|
||||
// Create abort controller for this request
|
||||
const controller = new AbortController();
|
||||
this.currentRequestController = controller;
|
||||
|
||||
try {
|
||||
const systemPrompt = {
|
||||
role: "system" as const,
|
||||
content:
|
||||
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
|
||||
};
|
||||
|
||||
// Build messages: all messages before this one, plus the prefix as assistant
|
||||
const apiMessages: { role: string; content: string }[] = [systemPrompt];
|
||||
for (let i = 0; i < messageIdx; i++) {
|
||||
const m = this.messages[i];
|
||||
apiMessages.push({ role: m.role, content: m.content || "" });
|
||||
}
|
||||
// Add the prefix as a partial assistant response to continue from
|
||||
apiMessages.push({ role: "assistant", content: prefixText });
|
||||
|
||||
// Determine which model to use
|
||||
let modelToUse = this.selectedChatModel;
|
||||
if (!modelToUse) {
|
||||
const firstInstanceKey = Object.keys(this.instances)[0];
|
||||
if (firstInstanceKey) {
|
||||
const instance = this.instances[firstInstanceKey] as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
if (instance) {
|
||||
const keys = Object.keys(instance);
|
||||
if (keys.length === 1) {
|
||||
const inst = instance[keys[0]] as
|
||||
| { shardAssignments?: { modelId?: string } }
|
||||
| undefined;
|
||||
modelToUse = inst?.shardAssignments?.modelId || "";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!modelToUse) {
|
||||
this.messages[messageIdx].content =
|
||||
prefixText + "\n\nError: No model available.";
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
continue_from_prefix: true,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
this.messages[messageIdx].content =
|
||||
prefixText + `\n\nError: ${response.status} - ${errorText}`;
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
this.messages[messageIdx].content =
|
||||
prefixText + "\n\nError: No response stream available";
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = prefixText;
|
||||
let partialLine = "";
|
||||
const collectedTokens: TokenData[] = [...keptTokens];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
const lines = (partialLine + chunk).split("\n");
|
||||
partialLine = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed === "data: [DONE]") continue;
|
||||
|
||||
if (trimmed.startsWith("data: ")) {
|
||||
try {
|
||||
const json = JSON.parse(trimmed.slice(6));
|
||||
const delta = json.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = json.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || delta,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
this.currentResponse = displayContent;
|
||||
|
||||
this.messages[messageIdx].content = displayContent;
|
||||
this.messages[messageIdx].thinking =
|
||||
thinkingContent || undefined;
|
||||
this.messages[messageIdx].tokens = [...collectedTokens];
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
} catch {
|
||||
// Skip malformed JSON
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final cleanup
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
this.messages[messageIdx].content = displayContent;
|
||||
this.messages[messageIdx].thinking = thinkingContent || undefined;
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[messageIdx].tokens = collectedTokens;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
// Don't show error for aborted requests (user cancelled)
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
this.messages[messageIdx].content =
|
||||
prefixText +
|
||||
`\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
// Clean up controller if this is still the active request
|
||||
if (this.currentRequestController === controller) {
|
||||
this.currentRequestController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.updateActiveConversation();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Strip thinking tags from content for display.
|
||||
* Handles both complete <think>...</think> blocks and in-progress <think>... blocks during streaming.
|
||||
@@ -1551,10 +1274,6 @@ class AppStore {
|
||||
this.startChat();
|
||||
}
|
||||
|
||||
// Create abort controller for this request
|
||||
const controller = new AbortController();
|
||||
this.currentRequestController = controller;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
this.ttftMs = null;
|
||||
@@ -1689,10 +1408,7 @@ class AppStore {
|
||||
messages: apiMessages,
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1708,7 +1424,6 @@ class AppStore {
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = "";
|
||||
let buffer = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -1748,29 +1463,6 @@ class AppStore {
|
||||
this.tps = (tokenCount / elapsed) * 1000;
|
||||
}
|
||||
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = parsed.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || tokenContent,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += tokenContent;
|
||||
|
||||
// Strip thinking tags for display and extract thinking content
|
||||
@@ -1785,8 +1477,6 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
// Update tokens during streaming for real-time visualization
|
||||
this.messages[idx].tokens = [...collectedTokens];
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
@@ -1834,17 +1524,9 @@ class AppStore {
|
||||
if (this.tps !== null) {
|
||||
this.messages[idx].tps = this.tps;
|
||||
}
|
||||
// Store token data for uncertainty visualization
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[idx].tokens = collectedTokens;
|
||||
}
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
// Don't show error for aborted requests (user cancelled)
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
console.error("Error sending message:", error);
|
||||
// Update the assistant message with error
|
||||
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
|
||||
@@ -1854,10 +1536,6 @@ class AppStore {
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
// Clean up controller if this is still the active request
|
||||
if (this.currentRequestController === controller) {
|
||||
this.currentRequestController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.updateActiveConversation();
|
||||
@@ -1937,9 +1615,6 @@ export const editMessage = (messageId: string, newContent: string) =>
|
||||
export const editAndRegenerate = (messageId: string, newContent: string) =>
|
||||
appStore.editAndRegenerate(messageId, newContent);
|
||||
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
|
||||
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
|
||||
appStore.regenerateFromToken(messageId, tokenIndex);
|
||||
export const abortCurrentRequest = () => appStore.abortCurrentRequest();
|
||||
|
||||
// Conversation actions
|
||||
export const conversations = () => appStore.conversations;
|
||||
|
||||
@@ -400,8 +400,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const errorText = await response.text();
|
||||
console.error('Failed to launch instance:', errorText);
|
||||
} else {
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
// Auto-select the launched model only if no model is currently selected
|
||||
if (!selectedChatModel()) {
|
||||
setSelectedChatModel(modelId);
|
||||
}
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
@@ -761,10 +763,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
async function deleteInstance(instanceId: string) {
|
||||
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
|
||||
|
||||
// Get the model ID of the instance being deleted before we delete it
|
||||
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
|
||||
const wasSelected = selectedChatModel() === deletedInstanceModelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/instance/${instanceId}`, {
|
||||
method: 'DELETE',
|
||||
@@ -773,24 +771,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Failed to delete instance:', response.status);
|
||||
} else if (wasSelected) {
|
||||
// If we deleted the currently selected model, switch to another available model
|
||||
// Find another instance that isn't the one we just deleted
|
||||
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
|
||||
if (remainingInstances.length > 0) {
|
||||
// Select the last instance (most recently added, since objects preserve insertion order)
|
||||
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
|
||||
const newModelId = getInstanceModelId(lastInstance);
|
||||
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
|
||||
setSelectedChatModel(newModelId);
|
||||
} else {
|
||||
// Clear selection if no valid model found
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
} else {
|
||||
// No more instances, clear the selection
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting instance:', error);
|
||||
|
||||
2
justfile
2
justfile
@@ -1,5 +1,3 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ dependencies = [
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"tomlkit>=0.14.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
15
resources/model_cards/deepseek-v3.1-4bit.toml
Normal file
15
resources/model_cards/deepseek-v3.1-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "deepseek-v3.1-4bit"
|
||||
model_id = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
name = "DeepSeek V3.1 (4-bit)"
|
||||
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
pretty_name = "DeepSeek V3.1 (4-bit)"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 405874409472
|
||||
15
resources/model_cards/deepseek-v3.1-8bit.toml
Normal file
15
resources/model_cards/deepseek-v3.1-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "deepseek-v3.1-8bit"
|
||||
model_id = "mlx-community/DeepSeek-V3.1-8bit"
|
||||
name = "DeepSeek V3.1 (8-bit)"
|
||||
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/DeepSeek-V3.1-8bit"
|
||||
pretty_name = "DeepSeek V3.1 (8-bit)"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 765577920512
|
||||
15
resources/model_cards/glm-4.5-air-8bit.toml
Normal file
15
resources/model_cards/glm-4.5-air-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "glm-4.5-air-8bit"
|
||||
model_id = "mlx-community/GLM-4.5-Air-8bit"
|
||||
name = "GLM 4.5 Air 8bit"
|
||||
description = "GLM 4.5 Air 8bit"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/GLM-4.5-Air-8bit"
|
||||
pretty_name = "GLM 4.5 Air 8bit"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = false
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 122406567936
|
||||
15
resources/model_cards/glm-4.5-air-bf16.toml
Normal file
15
resources/model_cards/glm-4.5-air-bf16.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "glm-4.5-air-bf16"
|
||||
model_id = "mlx-community/GLM-4.5-Air-bf16"
|
||||
name = "GLM 4.5 Air bf16"
|
||||
description = "GLM 4.5 Air bf16"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/GLM-4.5-Air-bf16"
|
||||
pretty_name = "GLM 4.5 Air bf16"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 229780750336
|
||||
15
resources/model_cards/glm-4.7-4bit.toml
Normal file
15
resources/model_cards/glm-4.7-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "glm-4.7-4bit"
|
||||
model_id = "mlx-community/GLM-4.7-4bit"
|
||||
name = "GLM 4.7 4bit"
|
||||
description = "GLM 4.7 4bit"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/GLM-4.7-4bit"
|
||||
pretty_name = "GLM 4.7 4bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 198556925568
|
||||
15
resources/model_cards/glm-4.7-6bit.toml
Normal file
15
resources/model_cards/glm-4.7-6bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "glm-4.7-6bit"
|
||||
model_id = "mlx-community/GLM-4.7-6bit"
|
||||
name = "GLM 4.7 6bit"
|
||||
description = "GLM 4.7 6bit"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/GLM-4.7-6bit"
|
||||
pretty_name = "GLM 4.7 6bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 286737579648
|
||||
15
resources/model_cards/glm-4.7-8bit-gs32.toml
Normal file
15
resources/model_cards/glm-4.7-8bit-gs32.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "glm-4.7-8bit-gs32"
|
||||
model_id = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
name = "GLM 4.7 8bit (gs32)"
|
||||
description = "GLM 4.7 8bit (gs32)"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
pretty_name = "GLM 4.7 8bit (gs32)"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 396963397248
|
||||
15
resources/model_cards/gpt-oss-120b-MXFP4-Q8.toml
Normal file
15
resources/model_cards/gpt-oss-120b-MXFP4-Q8.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "gpt-oss-120b-MXFP4-Q8"
|
||||
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
|
||||
description = "OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
pretty_name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
|
||||
n_layers = 36
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 70652212224
|
||||
15
resources/model_cards/gpt-oss-20b-4bit.toml
Normal file
15
resources/model_cards/gpt-oss-20b-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "gpt-oss-20b-4bit"
|
||||
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
|
||||
name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
|
||||
description = "OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
|
||||
pretty_name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
|
||||
n_layers = 24
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 12025908224
|
||||
15
resources/model_cards/kimi-k2-instruct-4bit.toml
Normal file
15
resources/model_cards/kimi-k2-instruct-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "kimi-k2-instruct-4bit"
|
||||
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
name = "Kimi K2 Instruct (4-bit)"
|
||||
description = "Kimi K2 is a large language model trained on the Kimi K2 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
pretty_name = "Kimi K2 Instruct (4-bit)"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 620622774272
|
||||
15
resources/model_cards/kimi-k2-thinking.toml
Normal file
15
resources/model_cards/kimi-k2-thinking.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "kimi-k2-thinking"
|
||||
model_id = "mlx-community/Kimi-K2-Thinking"
|
||||
name = "Kimi K2 Thinking (4-bit)"
|
||||
description = "Kimi K2 Thinking is the latest, most capable version of open-source thinking model."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Kimi-K2-Thinking"
|
||||
pretty_name = "Kimi K2 Thinking (4-bit)"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 706522120192
|
||||
15
resources/model_cards/llama-3.1-70b.toml
Normal file
15
resources/model_cards/llama-3.1-70b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.1-70b"
|
||||
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
name = "Llama 3.1 70B (4-bit)"
|
||||
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
pretty_name = "Llama 3.1 70B (4-bit)"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 40652242944
|
||||
15
resources/model_cards/llama-3.1-8b-8bit.toml
Normal file
15
resources/model_cards/llama-3.1-8b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.1-8b-8bit"
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
name = "Llama 3.1 8B (8-bit)"
|
||||
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
pretty_name = "Llama 3.1 8B (8-bit)"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 8954839040
|
||||
15
resources/model_cards/llama-3.1-8b-bf16.toml
Normal file
15
resources/model_cards/llama-3.1-8b-bf16.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.1-8b-bf16"
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
name = "Llama 3.1 8B (BF16)"
|
||||
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
pretty_name = "Llama 3.1 8B (BF16)"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 16882073600
|
||||
15
resources/model_cards/llama-3.1-8b.toml
Normal file
15
resources/model_cards/llama-3.1-8b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.1-8b"
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
name = "Llama 3.1 8B (4-bit)"
|
||||
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
pretty_name = "Llama 3.1 8B (4-bit)"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 4637851648
|
||||
15
resources/model_cards/llama-3.2-1b.toml
Normal file
15
resources/model_cards/llama-3.2-1b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.2-1b"
|
||||
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
name = "Llama 3.2 1B (4-bit)"
|
||||
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
pretty_name = "Llama 3.2 1B (4-bit)"
|
||||
n_layers = 16
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 729808896
|
||||
15
resources/model_cards/llama-3.2-3b-8bit.toml
Normal file
15
resources/model_cards/llama-3.2-3b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.2-3b-8bit"
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
name = "Llama 3.2 3B (8-bit)"
|
||||
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
pretty_name = "Llama 3.2 3B (8-bit)"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 3501195264
|
||||
15
resources/model_cards/llama-3.2-3b.toml
Normal file
15
resources/model_cards/llama-3.2-3b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.2-3b"
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
name = "Llama 3.2 3B (4-bit)"
|
||||
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
pretty_name = "Llama 3.2 3B (4-bit)"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 1863319552
|
||||
15
resources/model_cards/llama-3.3-70b-8bit.toml
Normal file
15
resources/model_cards/llama-3.3-70b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.3-70b-8bit"
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
name = "Llama 3.3 70B (8-bit)"
|
||||
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
pretty_name = "Llama 3.3 70B (8-bit)"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 76799803392
|
||||
15
resources/model_cards/llama-3.3-70b-fp16.toml
Normal file
15
resources/model_cards/llama-3.3-70b-fp16.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.3-70b-fp16"
|
||||
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
name = "Llama 3.3 70B (FP16)"
|
||||
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
pretty_name = "Llama 3.3 70B (FP16)"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 144383672320
|
||||
15
resources/model_cards/llama-3.3-70b.toml
Normal file
15
resources/model_cards/llama-3.3-70b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "llama-3.3-70b"
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
name = "Llama 3.3 70B (4-bit)"
|
||||
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
pretty_name = "Llama 3.3 70B"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 40652242944
|
||||
15
resources/model_cards/minimax-m2.1-3bit.toml
Normal file
15
resources/model_cards/minimax-m2.1-3bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "minimax-m2.1-3bit"
|
||||
model_id = "mlx-community/MiniMax-M2.1-3bit"
|
||||
name = "MiniMax M2.1 3bit"
|
||||
description = "MiniMax M2.1 3bit"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/MiniMax-M2.1-3bit"
|
||||
pretty_name = "MiniMax M2.1 3bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 100086644736
|
||||
15
resources/model_cards/minimax-m2.1-8bit.toml
Normal file
15
resources/model_cards/minimax-m2.1-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "minimax-m2.1-8bit"
|
||||
model_id = "mlx-community/MiniMax-M2.1-8bit"
|
||||
name = "MiniMax M2.1 8bit"
|
||||
description = "MiniMax M2.1 8bit"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/MiniMax-M2.1-8bit"
|
||||
pretty_name = "MiniMax M2.1 8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 242986745856
|
||||
15
resources/model_cards/qwen3-0.6b-8bit.toml
Normal file
15
resources/model_cards/qwen3-0.6b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-0.6b-8bit"
|
||||
model_id = "mlx-community/Qwen3-0.6B-8bit"
|
||||
name = "Qwen3 0.6B (8-bit)"
|
||||
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-0.6B-8bit"
|
||||
pretty_name = "Qwen3 0.6B (8-bit)"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 698351616
|
||||
15
resources/model_cards/qwen3-0.6b.toml
Normal file
15
resources/model_cards/qwen3-0.6b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-0.6b"
|
||||
model_id = "mlx-community/Qwen3-0.6B-4bit"
|
||||
name = "Qwen3 0.6B (4-bit)"
|
||||
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-0.6B-4bit"
|
||||
pretty_name = "Qwen3 0.6B (4-bit)"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 342884352
|
||||
15
resources/model_cards/qwen3-235b-a22b-4bit.toml
Normal file
15
resources/model_cards/qwen3-235b-a22b-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-235b-a22b-4bit"
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
name = "Qwen3 235B A22B (4-bit)"
|
||||
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
pretty_name = "Qwen3 235B A22B (4-bit)"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 141733920768
|
||||
15
resources/model_cards/qwen3-235b-a22b-8bit.toml
Normal file
15
resources/model_cards/qwen3-235b-a22b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-235b-a22b-8bit"
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
name = "Qwen3 235B A22B (8-bit)"
|
||||
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
pretty_name = "Qwen3 235B A22B (8-bit)"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 268435456000
|
||||
15
resources/model_cards/qwen3-30b-8bit.toml
Normal file
15
resources/model_cards/qwen3-30b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-30b-8bit"
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
name = "Qwen3 30B A3B (8-bit)"
|
||||
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
pretty_name = "Qwen3 30B A3B (8-bit)"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 33279705088
|
||||
15
resources/model_cards/qwen3-30b.toml
Normal file
15
resources/model_cards/qwen3-30b.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-30b"
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
name = "Qwen3 30B A3B (4-bit)"
|
||||
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
pretty_name = "Qwen3 30B A3B (4-bit)"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 17612931072
|
||||
15
resources/model_cards/qwen3-80b-a3B-4bit.toml
Normal file
15
resources/model_cards/qwen3-80b-a3B-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-80b-a3B-4bit"
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
name = "Qwen3 80B A3B (4-bit)"
|
||||
description = "Qwen3 80B"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
pretty_name = "Qwen3 80B A3B (4-bit)"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 46976204800
|
||||
15
resources/model_cards/qwen3-80b-a3B-8bit.toml
Normal file
15
resources/model_cards/qwen3-80b-a3B-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-80b-a3B-8bit"
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
name = "Qwen3 80B A3B (8-bit)"
|
||||
description = "Qwen3 80B"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
pretty_name = "Qwen3 80B A3B (8-bit)"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 88814387200
|
||||
15
resources/model_cards/qwen3-80b-a3B-thinking-4bit.toml
Normal file
15
resources/model_cards/qwen3-80b-a3B-thinking-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-80b-a3B-thinking-4bit"
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
name = "Qwen3 80B A3B Thinking (4-bit)"
|
||||
description = "Qwen3 80B Reasoning model"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
pretty_name = "Qwen3 80B A3B (4-bit)"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 88814387200
|
||||
15
resources/model_cards/qwen3-80b-a3B-thinking-8bit.toml
Normal file
15
resources/model_cards/qwen3-80b-a3B-thinking-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-80b-a3B-thinking-8bit"
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
name = "Qwen3 80B A3B Thinking (8-bit)"
|
||||
description = "Qwen3 80B Reasoning model"
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
pretty_name = "Qwen3 80B A3B (8-bit)"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 88814387200
|
||||
15
resources/model_cards/qwen3-coder-480b-a35b-4bit.toml
Normal file
15
resources/model_cards/qwen3-coder-480b-a35b-4bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-coder-480b-a35b-4bit"
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
name = "Qwen3 Coder 480B A35B (4-bit)"
|
||||
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
pretty_name = "Qwen3 Coder 480B A35B (4-bit)"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 289910292480
|
||||
15
resources/model_cards/qwen3-coder-480b-a35b-8bit.toml
Normal file
15
resources/model_cards/qwen3-coder-480b-a35b-8bit.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
short_id = "qwen3-coder-480b-a35b-8bit"
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
|
||||
name = "Qwen3 Coder 480B A35B (8-bit)"
|
||||
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
|
||||
tags = []
|
||||
|
||||
[metadata]
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
|
||||
pretty_name = "Qwen3 Coder 480B A35B (8-bit)"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
[metadata.storage_size]
|
||||
in_bytes = 579820584960
|
||||
@@ -1 +0,0 @@
|
||||
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""
|
||||
@@ -1,184 +0,0 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
FinishReason,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeUsage,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
|
||||
|
||||
def finish_reason_to_claude_stop_reason(
|
||||
finish_reason: FinishReason | None,
|
||||
) -> ClaudeStopReason | None:
|
||||
"""Map OpenAI finish_reason to Claude stop_reason."""
|
||||
if finish_reason is None:
|
||||
return None
|
||||
mapping: dict[FinishReason, ClaudeStopReason] = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
"function_call": "tool_use",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
|
||||
def claude_request_to_chat_params(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> ChatCompletionTaskParams:
|
||||
"""Convert Claude Messages API request to internal ChatCompletionTaskParams."""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
# Add system message if present
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
messages.append(
|
||||
ChatCompletionMessage(role="system", content=request.system)
|
||||
)
|
||||
else:
|
||||
# List of text blocks
|
||||
system_text = "".join(block.text for block in request.system)
|
||||
messages.append(ChatCompletionMessage(role="system", content=system_text))
|
||||
|
||||
# Convert messages
|
||||
for msg in request.messages:
|
||||
content: str
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
else:
|
||||
# Concatenate text blocks (images not supported for MVP)
|
||||
text_parts: list[str] = []
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
messages.append(ChatCompletionMessage(role=msg.role, content=content))
|
||||
|
||||
return ChatCompletionTaskParams(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
|
||||
def chat_response_to_claude_response(
|
||||
response: ChatCompletionResponse,
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Convert internal ChatCompletionResponse to Claude Messages API response."""
|
||||
content_text = ""
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
|
||||
content_text = (
|
||||
choice.message.content
|
||||
if isinstance(choice.message.content, str)
|
||||
else str(choice.message.content)
|
||||
)
|
||||
stop_reason = finish_reason_to_claude_stop_reason(choice.finish_reason)
|
||||
|
||||
# Use actual usage data from response if available
|
||||
input_tokens = response.usage.prompt_tokens if response.usage else 0
|
||||
output_tokens = response.usage.completion_tokens if response.usage else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
id=f"msg_{response.id}",
|
||||
model=response.model,
|
||||
content=[ClaudeTextBlock(text=content_text)],
|
||||
stop_reason=stop_reason,
|
||||
usage=ClaudeUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Claude Messages API streaming events from TokenChunks."""
|
||||
# Initial message_start event
|
||||
initial_message = ClaudeMessageStart(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
|
||||
)
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
|
||||
# content_block_stop
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason=stop_reason),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
|
||||
)
|
||||
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
|
||||
|
||||
# message_stop
|
||||
message_stop = ClaudeMessageStopEvent()
|
||||
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"
|
||||
@@ -1,199 +0,0 @@
|
||||
"""OpenAI Responses API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
|
||||
|
||||
def responses_request_to_chat_params(
|
||||
request: ResponsesRequest,
|
||||
) -> ChatCompletionTaskParams:
|
||||
"""Convert OpenAI Responses API request to internal ChatCompletionTaskParams."""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
# Add instructions as system message if present
|
||||
if request.instructions:
|
||||
messages.append(
|
||||
ChatCompletionMessage(role="system", content=request.instructions)
|
||||
)
|
||||
|
||||
# Convert input to messages
|
||||
if isinstance(request.input, str):
|
||||
messages.append(ChatCompletionMessage(role="user", content=request.input))
|
||||
else:
|
||||
for msg in request.input:
|
||||
messages.append(
|
||||
ChatCompletionMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatCompletionTaskParams(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_output_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
|
||||
def chat_response_to_responses_response(
|
||||
response: ChatCompletionResponse,
|
||||
) -> ResponsesResponse:
|
||||
"""Convert internal ChatCompletionResponse to OpenAI Responses API response."""
|
||||
output_text = ""
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
|
||||
output_text = (
|
||||
choice.message.content
|
||||
if isinstance(choice.message.content, str)
|
||||
else str(choice.message.content)
|
||||
)
|
||||
|
||||
item_id = f"item_{response.id}"
|
||||
output_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=output_text)],
|
||||
)
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=response.usage.prompt_tokens,
|
||||
output_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return ResponsesResponse(
|
||||
id=f"resp_{response.id}",
|
||||
model=response.model,
|
||||
output=[output_item],
|
||||
output_text=output_text,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
|
||||
# response.created
|
||||
initial_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
created_event = ResponseCreatedEvent(response=initial_response)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(response=initial_response)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
output_index=0, content_index=0, part=initial_part
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
|
||||
accumulated_text = ""
|
||||
last_stats = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
output_index=0, content_index=0, text=accumulated_text
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
output_index=0, content_index=0, part=final_part
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
final_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=[final_item],
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
completed_event = ResponseCompletedEvent(response=final_response)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
@@ -13,17 +13,13 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
chat_response_to_claude_response,
|
||||
claude_request_to_chat_params,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.master.adapters.responses import (
|
||||
chat_response_to_responses_response,
|
||||
generate_responses_stream,
|
||||
responses_request_to_chat_params,
|
||||
)
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.election import ElectionMessage
|
||||
@@ -41,8 +37,6 @@ from exo.shared.types.api import (
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -51,10 +45,6 @@ from exo.shared.types.api import (
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
)
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
@@ -68,10 +58,6 @@ from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -81,24 +67,12 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
# Build logprobs if available
|
||||
logprobs: Logprobs | None = None
|
||||
if chunk.logprob is not None:
|
||||
logprobs = Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
bytes=list(chunk.text.encode("utf-8")),
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -107,7 +81,6 @@ def chunk_to_response(
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -203,8 +176,6 @@ class API:
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
||||
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -410,8 +381,35 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
@@ -419,10 +417,16 @@ class API:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -438,11 +442,11 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
@@ -454,7 +458,7 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
@@ -462,7 +466,7 @@ class API:
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -491,7 +495,7 @@ class API:
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
@@ -499,7 +503,7 @@ class API:
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -540,6 +544,8 @@ class API:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -556,16 +562,17 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
@@ -582,78 +589,12 @@ class API:
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
return response
|
||||
|
||||
async def claude_messages(
|
||||
self, payload: ClaudeMessagesRequest
|
||||
) -> ClaudeMessagesResponse | StreamingResponse:
|
||||
"""Handle Claude Messages API requests."""
|
||||
chat_params = claude_request_to_chat_params(payload)
|
||||
model_meta = await resolve_model_meta(chat_params.model)
|
||||
chat_params.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == chat_params.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(chat_params.model)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {chat_params.model}",
|
||||
)
|
||||
|
||||
command = ChatCompletion(request_params=chat_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_claude_stream(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._chat_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
response = await self._collect_chat_completion(command.command_id)
|
||||
return chat_response_to_claude_response(response)
|
||||
|
||||
async def openai_responses(
|
||||
self, payload: ResponsesRequest
|
||||
) -> ResponsesResponse | StreamingResponse:
|
||||
"""Handle OpenAI Responses API requests."""
|
||||
chat_params = responses_request_to_chat_params(payload)
|
||||
|
||||
model_meta = await resolve_model_meta(chat_params.model)
|
||||
chat_params.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == chat_params.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(chat_params.model)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {chat_params.model}",
|
||||
)
|
||||
|
||||
command = ChatCompletion(request_params=chat_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_responses_stream(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._chat_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
response = await self._collect_chat_completion(command.command_id)
|
||||
return chat_response_to_responses_response(response)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
@@ -718,17 +659,9 @@ class API:
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
try:
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
except (anyio.BrokenResourceError, KeyError):
|
||||
# Client disconnected, queue was closed/removed - this is expected
|
||||
# when clients abort requests (e.g., regenerate from token)
|
||||
logger.debug(
|
||||
f"Client disconnected for command {event.command_id}, "
|
||||
"dropping chunk"
|
||||
)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -1,392 +0,0 @@
|
||||
"""Tests for Claude Messages API conversion functions and types."""
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
chat_response_to_claude_response,
|
||||
claude_request_to_chat_params,
|
||||
finish_reason_to_claude_stop_reason,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeMessage,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeUsage,
|
||||
)
|
||||
|
||||
|
||||
class TestFinishReasonToClaudeStopReason:
|
||||
"""Tests for finish_reason to Claude stop_reason mapping."""
|
||||
|
||||
def test_stop_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
|
||||
|
||||
def test_length_maps_to_max_tokens(self):
|
||||
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
|
||||
|
||||
def test_tool_calls_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
|
||||
|
||||
def test_function_call_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
|
||||
|
||||
def test_content_filter_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert finish_reason_to_claude_stop_reason(None) is None
|
||||
|
||||
|
||||
class TestClaudeRequestToChatParams:
|
||||
"""Tests for converting Claude Messages API requests to ChatCompletionTaskParams."""
|
||||
|
||||
def test_basic_request_conversion(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert params.model == "claude-3-opus"
|
||||
assert params.max_tokens == 100
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello"
|
||||
|
||||
def test_request_with_system_string(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
system="You are a helpful assistant.",
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are a helpful assistant."
|
||||
assert params.messages[1].role == "user"
|
||||
assert params.messages[1].content == "Hello"
|
||||
|
||||
def test_request_with_system_text_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
system=[
|
||||
ClaudeTextBlock(text="You are helpful. "),
|
||||
ClaudeTextBlock(text="Be concise."),
|
||||
],
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are helpful. Be concise."
|
||||
|
||||
def test_request_with_content_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(
|
||||
role="user",
|
||||
content=[
|
||||
ClaudeTextBlock(text="First part. "),
|
||||
ClaudeTextBlock(text="Second part."),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].content == "First part. Second part."
|
||||
|
||||
def test_request_with_multi_turn_conversation(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
ClaudeMessage(role="assistant", content="Hi there!"),
|
||||
ClaudeMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 3
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[1].role == "assistant"
|
||||
assert params.messages[2].role == "user"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[ClaudeMessage(role="user", content="Hello")],
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
stop_sequences=["STOP", "END"],
|
||||
stream=True,
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert params.temperature == 0.7
|
||||
assert params.top_p == 0.9
|
||||
assert params.top_k == 40
|
||||
assert params.stop == ["STOP", "END"]
|
||||
assert params.stream is True
|
||||
|
||||
|
||||
class TestChatResponseToClaudeResponse:
|
||||
"""Tests for converting ChatCompletionResponse to Claude Messages API response."""
|
||||
|
||||
def test_basic_response_conversion(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=7, total_tokens=17),
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.id == "msg_chatcmpl-123"
|
||||
assert claude_response.model == "llama-3.2-1b"
|
||||
assert claude_response.role == "assistant"
|
||||
assert claude_response.type == "message"
|
||||
assert len(claude_response.content) == 1
|
||||
assert claude_response.content[0].type == "text"
|
||||
assert claude_response.content[0].text == "Hello! How can I help you?"
|
||||
assert claude_response.stop_reason == "end_turn"
|
||||
assert claude_response.usage.input_tokens == 10
|
||||
assert claude_response.usage.output_tokens == 7
|
||||
|
||||
def test_response_with_length_finish_reason(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content="Truncated..."
|
||||
),
|
||||
finish_reason="length",
|
||||
)
|
||||
],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.stop_reason == "max_tokens"
|
||||
|
||||
def test_response_with_empty_content(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == ""
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
def test_response_with_no_choices(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == ""
|
||||
assert claude_response.stop_reason is None
|
||||
assert claude_response.usage.input_tokens == 0
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
def test_response_without_usage(self):
|
||||
"""Test response conversion when usage data is not available."""
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == "Hello!"
|
||||
assert claude_response.usage.input_tokens == 0
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
|
||||
class TestClaudeMessagesRequestValidation:
|
||||
"""Tests for Claude Messages API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_max_tokens(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_messages(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 100,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestClaudeStreamingEvents:
|
||||
"""Tests for Claude Messages API streaming event serialization."""
|
||||
|
||||
def test_message_start_event_format(self):
|
||||
message = ClaudeMessageStart(
|
||||
id="msg_123",
|
||||
model="claude-3-opus",
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
|
||||
)
|
||||
event = ClaudeMessageStartEvent(message=message)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_start"
|
||||
assert parsed["message"]["id"] == "msg_123"
|
||||
assert parsed["message"]["type"] == "message"
|
||||
assert parsed["message"]["role"] == "assistant"
|
||||
assert parsed["message"]["model"] == "claude-3-opus"
|
||||
|
||||
def test_content_block_start_event_format(self):
|
||||
event = ClaudeContentBlockStartEvent(
|
||||
index=0,
|
||||
content_block=ClaudeTextBlock(text=""),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_start"
|
||||
assert parsed["index"] == 0
|
||||
assert parsed["content_block"]["type"] == "text"
|
||||
assert parsed["content_block"]["text"] == ""
|
||||
|
||||
def test_content_block_delta_event_format(self):
|
||||
event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text="Hello"),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_delta"
|
||||
assert parsed["index"] == 0
|
||||
assert parsed["delta"]["type"] == "text_delta"
|
||||
assert parsed["delta"]["text"] == "Hello"
|
||||
|
||||
def test_content_block_stop_event_format(self):
|
||||
event = ClaudeContentBlockStopEvent(index=0)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_stop"
|
||||
assert parsed["index"] == 0
|
||||
|
||||
def test_message_delta_event_format(self):
|
||||
event = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason="end_turn"),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=25),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_delta"
|
||||
assert parsed["delta"]["stop_reason"] == "end_turn"
|
||||
assert parsed["usage"]["output_tokens"] == 25
|
||||
|
||||
def test_message_stop_event_format(self):
|
||||
event = ClaudeMessageStopEvent()
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_stop"
|
||||
|
||||
def test_sse_format(self):
|
||||
"""Test that SSE format is correctly generated."""
|
||||
event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text="Hello"),
|
||||
)
|
||||
# Simulate the SSE format used in the streaming generator
|
||||
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
|
||||
|
||||
assert sse_line.startswith("event: content_block_delta\n")
|
||||
assert "data: " in sse_line
|
||||
assert sse_line.endswith("\n\n")
|
||||
@@ -1,414 +0,0 @@
|
||||
"""Tests for OpenAI Responses API conversion functions and types."""
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.responses import (
|
||||
chat_response_to_responses_response,
|
||||
responses_request_to_chat_params,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInputMessage,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesRequestToChatParams:
|
||||
"""Tests for converting OpenAI Responses API requests to ChatCompletionTaskParams."""
|
||||
|
||||
def test_string_input_conversion(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello, how are you?",
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert params.model == "gpt-4o"
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello, how are you?"
|
||||
|
||||
def test_message_array_input_conversion(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
ResponseInputMessage(role="assistant", content="Hi there!"),
|
||||
ResponseInputMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 3
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello"
|
||||
assert params.messages[1].role == "assistant"
|
||||
assert params.messages[1].content == "Hi there!"
|
||||
assert params.messages[2].role == "user"
|
||||
assert params.messages[2].content == "How are you?"
|
||||
|
||||
def test_request_with_instructions(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
instructions="You are a helpful assistant. Be concise.",
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are a helpful assistant. Be concise."
|
||||
assert params.messages[1].role == "user"
|
||||
assert params.messages[1].content == "Hello"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
max_output_tokens=500,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
stream=True,
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert params.max_tokens == 500
|
||||
assert params.temperature == 0.8
|
||||
assert params.top_p == 0.95
|
||||
assert params.stream is True
|
||||
|
||||
def test_request_with_system_role_in_messages(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="system", content="Be helpful"),
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[1].role == "user"
|
||||
|
||||
def test_request_with_developer_role(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="developer", content="Internal note"),
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "developer"
|
||||
|
||||
|
||||
class TestChatResponseToResponsesResponse:
|
||||
"""Tests for converting ChatCompletionResponse to OpenAI Responses API response."""
|
||||
|
||||
def test_basic_response_conversion(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.id == "resp_chatcmpl-123"
|
||||
assert responses_response.object == "response"
|
||||
assert responses_response.model == "llama-3.2-1b"
|
||||
assert responses_response.status == "completed"
|
||||
assert responses_response.output_text == "Hello! How can I help you?"
|
||||
assert len(responses_response.output) == 1
|
||||
assert responses_response.output[0].type == "message"
|
||||
assert responses_response.output[0].role == "assistant"
|
||||
assert len(responses_response.output[0].content) == 1
|
||||
assert responses_response.output[0].content[0].type == "output_text"
|
||||
assert (
|
||||
responses_response.output[0].content[0].text == "Hello! How can I help you?"
|
||||
)
|
||||
|
||||
def test_response_with_usage(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.usage is not None
|
||||
assert responses_response.usage.input_tokens == 10
|
||||
assert responses_response.usage.output_tokens == 5
|
||||
assert responses_response.usage.total_tokens == 15
|
||||
|
||||
def test_response_with_empty_content(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output_text == ""
|
||||
assert responses_response.output[0].content[0].text == ""
|
||||
|
||||
def test_response_with_no_choices(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output_text == ""
|
||||
|
||||
def test_response_without_usage(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.usage is None
|
||||
|
||||
def test_response_item_id_format(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-abc123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output[0].id == "item_chatcmpl-abc123"
|
||||
|
||||
|
||||
class TestResponsesRequestValidation:
|
||||
"""Tests for OpenAI Responses API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"input": "Hello",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_input(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_accepts_string_input(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
)
|
||||
assert request.input == "Hello"
|
||||
|
||||
def test_request_accepts_message_array_input(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[ResponseInputMessage(role="user", content="Hello")],
|
||||
)
|
||||
assert len(request.input) == 1
|
||||
|
||||
|
||||
class TestResponsesStreamingEvents:
|
||||
"""Tests for OpenAI Responses API streaming event serialization."""
|
||||
|
||||
def test_response_created_event_format(self):
|
||||
response = ResponsesResponse(
|
||||
id="resp_123",
|
||||
model="gpt-4o",
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
event = ResponseCreatedEvent(response=response)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.created"
|
||||
assert parsed["response"]["id"] == "resp_123"
|
||||
assert parsed["response"]["object"] == "response"
|
||||
assert parsed["response"]["status"] == "in_progress"
|
||||
|
||||
def test_output_item_added_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_item.added"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["item"]["type"] == "message"
|
||||
assert parsed["item"]["id"] == "item_123"
|
||||
assert parsed["item"]["role"] == "assistant"
|
||||
|
||||
def test_content_part_added_event_format(self):
|
||||
part = ResponseOutputText(text="")
|
||||
event = ResponseContentPartAddedEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=part,
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.content_part.added"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["content_index"] == 0
|
||||
assert parsed["part"]["type"] == "output_text"
|
||||
|
||||
def test_text_delta_event_format(self):
|
||||
event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta="Hello",
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_text.delta"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["content_index"] == 0
|
||||
assert parsed["delta"] == "Hello"
|
||||
|
||||
def test_text_done_event_format(self):
|
||||
event = ResponseTextDoneEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
text="Hello, world!",
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_text.done"
|
||||
assert parsed["text"] == "Hello, world!"
|
||||
|
||||
def test_output_item_done_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="Hello, world!")],
|
||||
status="completed",
|
||||
)
|
||||
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_item.done"
|
||||
assert parsed["item"]["status"] == "completed"
|
||||
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
|
||||
|
||||
def test_response_completed_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="Hello!")],
|
||||
status="completed",
|
||||
)
|
||||
response = ResponsesResponse(
|
||||
id="resp_123",
|
||||
model="gpt-4o",
|
||||
status="completed",
|
||||
output=[item],
|
||||
output_text="Hello!",
|
||||
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
|
||||
)
|
||||
event = ResponseCompletedEvent(response=response)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.completed"
|
||||
assert parsed["response"]["status"] == "completed"
|
||||
assert parsed["response"]["output_text"] == "Hello!"
|
||||
assert parsed["response"]["usage"]["total_tokens"] == 15
|
||||
|
||||
def test_sse_format(self):
|
||||
"""Test that SSE format is correctly generated."""
|
||||
event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta="Hello",
|
||||
)
|
||||
# Simulate the SSE format used in the streaming generator
|
||||
sse_line = (
|
||||
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
|
||||
)
|
||||
|
||||
assert sse_line.startswith("event: response.output_text.delta\n")
|
||||
assert "data: " in sse_line
|
||||
assert sse_line.endswith("\n\n")
|
||||
@@ -1,5 +1,8 @@
|
||||
from exo.shared.types.memory import Memory
|
||||
from anyio import Path, open_file
|
||||
import tomlkit
|
||||
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
@@ -11,542 +14,27 @@ class ModelCard(CamelCaseModel):
|
||||
tags: list[str]
|
||||
metadata: ModelMetadata
|
||||
|
||||
@staticmethod
|
||||
async def load(path: Path) -> "ModelCard":
|
||||
async with await open_file(path) as f:
|
||||
data = await f.read()
|
||||
py = tomlkit.loads(data)
|
||||
return ModelCard.model_validate(py)
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
short_id="deepseek-v3.1-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
name="DeepSeek V3.1 (4-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
pretty_name="DeepSeek V3.1 (4-bit)",
|
||||
storage_size=Memory.from_gb(378),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
short_id="deepseek-v3.1-8bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
name="DeepSeek V3.1 (8-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
pretty_name="DeepSeek V3.1 (8-bit)",
|
||||
storage_size=Memory.from_gb(713),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
name="Kimi K2 Instruct (4-bit)",
|
||||
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
pretty_name="Kimi K2 Instruct (4-bit)",
|
||||
storage_size=Memory.from_gb(578),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
short_id="kimi-k2-thinking",
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
name="Kimi K2 Thinking (4-bit)",
|
||||
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
pretty_name="Kimi K2 Thinking (4-bit)",
|
||||
storage_size=Memory.from_gb(658),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
short_id="llama-3.1-8b",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
name="Llama 3.1 8B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.1 8B (4-bit)",
|
||||
storage_size=Memory.from_mb(4423),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
short_id="llama-3.1-8b-8bit",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
name="Llama 3.1 8B (8-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.1 8B (8-bit)",
|
||||
storage_size=Memory.from_mb(8540),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
short_id="llama-3.1-8b-bf16",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
name="Llama 3.1 8B (BF16)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
pretty_name="Llama 3.1 8B (BF16)",
|
||||
storage_size=Memory.from_mb(16100),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
short_id="llama-3.1-70b",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
name="Llama 3.1 70B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.1 70B (4-bit)",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.2
|
||||
"llama-3.2-1b": ModelCard(
|
||||
short_id="llama-3.2-1b",
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
name="Llama 3.2 1B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.2 1B (4-bit)",
|
||||
storage_size=Memory.from_mb(696),
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
short_id="llama-3.2-3b",
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
name="Llama 3.2 3B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.2 3B (4-bit)",
|
||||
storage_size=Memory.from_mb(1777),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
short_id="llama-3.2-3b-8bit",
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
name="Llama 3.2 3B (8-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.2 3B (8-bit)",
|
||||
storage_size=Memory.from_mb(3339),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.3
|
||||
"llama-3.3-70b": ModelCard(
|
||||
short_id="llama-3.3-70b",
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
name="Llama 3.3 70B (4-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.3 70B",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
short_id="llama-3.3-70b-8bit",
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
name="Llama 3.3 70B (8-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.3 70B (8-bit)",
|
||||
storage_size=Memory.from_mb(73242),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
short_id="llama-3.3-70b-fp16",
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
name="Llama 3.3 70B (FP16)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
pretty_name="Llama 3.3 70B (FP16)",
|
||||
storage_size=Memory.from_mb(137695),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# qwen3
|
||||
"qwen3-0.6b": ModelCard(
|
||||
short_id="qwen3-0.6b",
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
name="Qwen3 0.6B (4-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
pretty_name="Qwen3 0.6B (4-bit)",
|
||||
storage_size=Memory.from_mb(327),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
short_id="qwen3-0.6b-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
name="Qwen3 0.6B (8-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
pretty_name="Qwen3 0.6B (8-bit)",
|
||||
storage_size=Memory.from_mb(666),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
short_id="qwen3-30b",
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
name="Qwen3 30B A3B (4-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
pretty_name="Qwen3 30B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(16797),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
short_id="qwen3-30b-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
name="Qwen3 30B A3B (8-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
pretty_name="Qwen3 30B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(31738),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
name="Qwen3 80B A3B (4-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(44800),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
name="Qwen3 80B A3B (8-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
name="Qwen3 80B A3B Thinking (4-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
name="Qwen3 80B A3B Thinking (8-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
short_id="qwen3-235b-a22b-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
name="Qwen3 235B A22B (4-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
pretty_name="Qwen3 235B A22B (4-bit)",
|
||||
storage_size=Memory.from_gb(132),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
short_id="qwen3-235b-a22b-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
name="Qwen3 235B A22B (8-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
pretty_name="Qwen3 235B A22B (8-bit)",
|
||||
storage_size=Memory.from_gb(250),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
short_id="qwen3-coder-480b-a35b-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
storage_size=Memory.from_gb(270),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
short_id="qwen3-coder-480b-a35b-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
storage_size=Memory.from_gb(540),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(68_996_301),
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-20b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
description="""GLM 4.5 Air 8bit""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
pretty_name="GLM 4.5 Air 8bit",
|
||||
storage_size=Memory.from_gb(114),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
short_id="glm-4.5-air-bf16",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
name="GLM 4.5 Air bf16",
|
||||
description="""GLM 4.5 Air bf16""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
pretty_name="GLM 4.5 Air bf16",
|
||||
storage_size=Memory.from_gb(214),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
short_id="glm-4.7-4bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
name="GLM 4.7 4bit",
|
||||
description="GLM 4.7 4bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
pretty_name="GLM 4.7 4bit",
|
||||
storage_size=Memory.from_bytes(198556925568),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
short_id="glm-4.7-6bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
name="GLM 4.7 6bit",
|
||||
description="GLM 4.7 6bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
pretty_name="GLM 4.7 6bit",
|
||||
storage_size=Memory.from_bytes(286737579648),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
short_id="glm-4.7-8bit-gs32",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
name="GLM 4.7 8bit (gs32)",
|
||||
description="GLM 4.7 8bit (gs32)",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
pretty_name="GLM 4.7 8bit (gs32)",
|
||||
storage_size=Memory.from_bytes(396963397248),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
short_id="minimax-m2.1-8bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
name="MiniMax M2.1 8bit",
|
||||
description="MiniMax M2.1 8bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
pretty_name="MiniMax M2.1 8bit",
|
||||
storage_size=Memory.from_bytes(242986745856),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"minimax-m2.1-3bit": ModelCard(
|
||||
short_id="minimax-m2.1-3bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
name="MiniMax M2.1 3bit",
|
||||
description="MiniMax M2.1 3bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
pretty_name="MiniMax M2.1 3bit",
|
||||
storage_size=Memory.from_bytes(100086644736),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
}
|
||||
async def save(self, path: Path):
|
||||
async with await open_file(path, "w") as f:
|
||||
py = self.model_dump()
|
||||
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
|
||||
await f.write(data)
|
||||
|
||||
@staticmethod
|
||||
async def from_hf(model_id: str) -> "ModelCard":
|
||||
short_name = model_id.split("/")[-1]
|
||||
return ModelCard(
|
||||
short_id=short_name,
|
||||
model_id=ModelId(model_id),
|
||||
name=short_name,
|
||||
description=f"Custom model from {model_id}",
|
||||
tags=[],
|
||||
metadata=await get_model_meta(model_id),
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.worker.download.download_utils import (
|
||||
@@ -108,19 +107,13 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
model_card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
|
||||
None,
|
||||
)
|
||||
|
||||
return ModelMetadata(
|
||||
model_id=ModelId(model_id),
|
||||
pretty_name=model_card.name if model_card is not None else model_id,
|
||||
pretty_name=model_id,
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
|
||||
supports_tensor=model_card.metadata.supports_tensor
|
||||
if model_card is not None
|
||||
else False,
|
||||
supports_tensor=False,
|
||||
)
|
||||
|
||||
@@ -146,12 +146,10 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
stream: bool = False
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
continue_from_prefix: bool = False # When True, continue the last assistant message
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.api import GenerationStats, TopLogprobItem
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -20,8 +20,6 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
logprob: float | None = None # Log probability of the selected token
|
||||
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Claude Messages API types for request/response conversion."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type aliases
|
||||
ClaudeRole = Literal["user", "assistant"]
|
||||
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
|
||||
|
||||
|
||||
# Content block types
|
||||
class ClaudeTextBlock(BaseModel, frozen=True):
|
||||
"""Text content block in Claude Messages API."""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeImageSource(BaseModel, frozen=True):
|
||||
"""Image source for Claude image blocks."""
|
||||
|
||||
type: Literal["base64", "url"]
|
||||
media_type: str | None = None
|
||||
data: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class ClaudeImageBlock(BaseModel, frozen=True):
|
||||
"""Image content block in Claude Messages API."""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
source: ClaudeImageSource
|
||||
|
||||
|
||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
|
||||
|
||||
|
||||
# Request types
|
||||
class ClaudeMessage(BaseModel, frozen=True):
|
||||
"""Message in Claude Messages API request."""
|
||||
|
||||
role: ClaudeRole
|
||||
content: str | list[ClaudeContentBlock]
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModel):
|
||||
"""Request body for Claude Messages API."""
|
||||
|
||||
model: str
|
||||
max_tokens: int
|
||||
messages: list[ClaudeMessage]
|
||||
system: str | list[ClaudeTextBlock] | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool = False
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
class ClaudeUsage(BaseModel, frozen=True):
|
||||
"""Token usage in Claude Messages API response."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
|
||||
|
||||
class ClaudeMessagesResponse(BaseModel, frozen=True):
|
||||
"""Response body for Claude Messages API."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ClaudeTextBlock]
|
||||
model: str
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: ClaudeUsage
|
||||
|
||||
|
||||
# Streaming event types
|
||||
class ClaudeMessageStart(BaseModel, frozen=True):
|
||||
"""Partial message in message_start event."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ClaudeTextBlock] = Field(default_factory=list)
|
||||
model: str
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: ClaudeUsage
|
||||
|
||||
|
||||
class ClaudeMessageStartEvent(BaseModel, frozen=True):
|
||||
"""Event sent at start of message stream."""
|
||||
|
||||
type: Literal["message_start"] = "message_start"
|
||||
message: ClaudeMessageStart
|
||||
|
||||
|
||||
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
||||
"""Event sent at start of a content block."""
|
||||
|
||||
type: Literal["content_block_start"] = "content_block_start"
|
||||
index: int
|
||||
content_block: ClaudeTextBlock
|
||||
|
||||
|
||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
"""Delta for text content block."""
|
||||
|
||||
type: Literal["text_delta"] = "text_delta"
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for content block delta."""
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: ClaudeTextDelta
|
||||
|
||||
|
||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||
"""Event sent at end of a content block."""
|
||||
|
||||
type: Literal["content_block_stop"] = "content_block_stop"
|
||||
index: int
|
||||
|
||||
|
||||
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
|
||||
"""Usage in message_delta event."""
|
||||
|
||||
output_tokens: int
|
||||
|
||||
|
||||
class ClaudeMessageDelta(BaseModel, frozen=True):
|
||||
"""Delta in message_delta event."""
|
||||
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
|
||||
|
||||
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent with final message delta."""
|
||||
|
||||
type: Literal["message_delta"] = "message_delta"
|
||||
delta: ClaudeMessageDelta
|
||||
usage: ClaudeMessageDeltaUsage
|
||||
|
||||
|
||||
class ClaudeMessageStopEvent(BaseModel, frozen=True):
|
||||
"""Event sent at end of message stream."""
|
||||
|
||||
type: Literal["message_stop"] = "message_stop"
|
||||
|
||||
|
||||
ClaudeStreamEvent = (
|
||||
ClaudeMessageStartEvent
|
||||
| ClaudeContentBlockStartEvent
|
||||
| ClaudeContentBlockDeltaEvent
|
||||
| ClaudeContentBlockStopEvent
|
||||
| ClaudeMessageDeltaEvent
|
||||
| ClaudeMessageStopEvent
|
||||
)
|
||||
@@ -1,162 +0,0 @@
|
||||
"""OpenAI Responses API types for request/response conversion."""
|
||||
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type aliases
|
||||
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
|
||||
ResponseRole = Literal["user", "assistant", "system", "developer"]
|
||||
|
||||
|
||||
# Request types
|
||||
class ResponseInputMessage(BaseModel, frozen=True):
|
||||
"""Input message for Responses API."""
|
||||
|
||||
role: ResponseRole
|
||||
content: str
|
||||
|
||||
|
||||
class ResponsesRequest(BaseModel):
|
||||
"""Request body for OpenAI Responses API."""
|
||||
|
||||
model: str
|
||||
input: str | list[ResponseInputMessage]
|
||||
instructions: str | None = None
|
||||
max_output_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
stream: bool = False
|
||||
# previous_response_id not supported in MVP
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
class ResponseOutputText(BaseModel, frozen=True):
|
||||
"""Text content in response output."""
|
||||
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
annotations: list[dict[str, str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ResponseMessageItem(BaseModel, frozen=True):
|
||||
"""Message item in response output array."""
|
||||
|
||||
type: Literal["message"] = "message"
|
||||
id: str
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ResponseOutputText]
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
|
||||
|
||||
|
||||
class ResponseUsage(BaseModel, frozen=True):
|
||||
"""Token usage in Responses API response."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ResponsesResponse(BaseModel, frozen=True):
|
||||
"""Response body for OpenAI Responses API."""
|
||||
|
||||
id: str
|
||||
object: Literal["response"] = "response"
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
status: ResponseStatus = "completed"
|
||||
model: str
|
||||
output: list[ResponseItem]
|
||||
output_text: str
|
||||
usage: ResponseUsage | None = None
|
||||
|
||||
|
||||
# Streaming event types
|
||||
class ResponseCreatedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is created."""
|
||||
|
||||
type: Literal["response.created"] = "response.created"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
class ResponseInProgressEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response starts processing."""
|
||||
|
||||
type: Literal["response.in_progress"] = "response.in_progress"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when an output item is added."""
|
||||
|
||||
type: Literal["response.output_item.added"] = "response.output_item.added"
|
||||
output_index: int
|
||||
item: ResponseItem
|
||||
|
||||
|
||||
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a content part is added."""
|
||||
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ResponseOutputText
|
||||
|
||||
|
||||
class ResponseTextDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for text delta during streaming."""
|
||||
|
||||
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseTextDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when text content is done."""
|
||||
|
||||
type: Literal["response.output_text.done"] = "response.output_text.done"
|
||||
output_index: int
|
||||
content_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a content part is done."""
|
||||
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ResponseOutputText
|
||||
|
||||
|
||||
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when an output item is done."""
|
||||
|
||||
type: Literal["response.output_item.done"] = "response.output_item.done"
|
||||
output_index: int
|
||||
item: ResponseItem
|
||||
|
||||
|
||||
class ResponseCompletedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is completed."""
|
||||
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
ResponsesStreamEvent = (
|
||||
ResponseCreatedEvent
|
||||
| ResponseInProgressEvent
|
||||
| ResponseOutputItemAddedEvent
|
||||
| ResponseContentPartAddedEvent
|
||||
| ResponseTextDeltaEvent
|
||||
| ResponseTextDoneEvent
|
||||
| ResponseContentPartDoneEvent
|
||||
| ResponseOutputItemDoneEvent
|
||||
| ResponseCompletedEvent
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -13,8 +13,7 @@ class TokenizedResponse(BaseRunnerResponse):
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
logprob: float | None = None # Log probability of the selected token
|
||||
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -40,6 +40,4 @@ class TokenizerWrapper:
|
||||
messages_dicts: list[dict[str, Any]],
|
||||
tokenize: bool = False,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> str: ...
|
||||
|
||||
@@ -12,7 +12,6 @@ from exo.shared.types.api import (
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
@@ -116,60 +115,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs: mx.array,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_k: int,
|
||||
selected_token: int,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top-k alternative tokens.
|
||||
|
||||
Args:
|
||||
logprobs: Full vocabulary logprobs array from MLX
|
||||
tokenizer: Tokenizer for decoding token IDs to strings
|
||||
top_k: Number of top alternatives to return
|
||||
selected_token: The token ID that was actually sampled
|
||||
|
||||
Returns:
|
||||
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
|
||||
"""
|
||||
# Get the logprob of the selected token
|
||||
selected_logprob = float(logprobs[selected_token].item())
|
||||
|
||||
# Get top-k indices (most probable tokens)
|
||||
# mx.argpartition gives indices that would partition the array
|
||||
# We negate logprobs since argpartition finds smallest, and we want largest
|
||||
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
|
||||
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
|
||||
|
||||
# Get the actual logprob values for these indices
|
||||
top_values = logprobs[top_indices]
|
||||
|
||||
# Sort by logprob (descending) for consistent ordering
|
||||
sort_order = mx.argsort(-top_values)
|
||||
top_indices = top_indices[sort_order]
|
||||
top_values = top_values[sort_order]
|
||||
|
||||
# Convert to list of TopLogprobItem
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for i in range(top_k):
|
||||
token_id = int(top_indices[i].item())
|
||||
token_logprob = float(top_values[i].item())
|
||||
# Decode token ID to string
|
||||
token_str = tokenizer.decode([token_id])
|
||||
# Get byte representation
|
||||
token_bytes = list(token_str.encode("utf-8"))
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=token_logprob,
|
||||
bytes=token_bytes,
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -201,24 +146,9 @@ def mlx_generate(
|
||||
sampler = make_sampler(
|
||||
temp=task.temperature if task.temperature is not None else 0.7,
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
top_k=task.top_k if task.top_k is not None else 0,
|
||||
)
|
||||
|
||||
# Normalize stop sequences to a list
|
||||
stop_sequences: list[str] = (
|
||||
([task.stop] if isinstance(task.stop, str) else task.stop)
|
||||
if task.stop is not None
|
||||
else []
|
||||
)
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
|
||||
# Determine if we need to extract logprobs
|
||||
should_extract_logprobs = task.logprobs is True
|
||||
num_top_logprobs = task.top_logprobs if task.top_logprobs is not None else 5
|
||||
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -233,41 +163,9 @@ def mlx_generate(
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
# Check for stop sequences
|
||||
text = out.text
|
||||
finish_reason: FinishReason | None = cast(
|
||||
FinishReason | None, out.finish_reason
|
||||
)
|
||||
stop_matched = False
|
||||
|
||||
if stop_sequences:
|
||||
for stop_seq in stop_sequences:
|
||||
if stop_seq in accumulated_text:
|
||||
# Trim text to just before the stop sequence
|
||||
stop_index = accumulated_text.find(stop_seq)
|
||||
text_before_stop = accumulated_text[:stop_index]
|
||||
chunk_start = len(accumulated_text) - len(out.text)
|
||||
text = text_before_stop[chunk_start:]
|
||||
finish_reason = "stop"
|
||||
stop_matched = True
|
||||
break
|
||||
|
||||
# Extract logprobs if requested
|
||||
token_logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if should_extract_logprobs:
|
||||
token_logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs=out.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_k=num_top_logprobs,
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
is_done = finish_reason is not None
|
||||
stats: GenerationStats | None = None
|
||||
if is_done:
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
@@ -275,25 +173,22 @@ def mlx_generate(
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
if not stop_matched and out.finish_reason not in get_args(FinishReason):
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
logprob=token_logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -20,7 +20,6 @@ except ImportError:
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -359,28 +358,12 @@ def apply_chat_template(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
|
||||
# Use continue_final_message when continuing from prefix (e.g., regenerate from token)
|
||||
# This keeps the final assistant message open without EOS tokens
|
||||
# Note: explicitly set add_generation_prompt=False when using continue_final_message
|
||||
# because some tokenizers (e.g., Kimi) default add_generation_prompt=True
|
||||
prompt: str
|
||||
if chat_task_data.continue_from_prefix:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
continue_final_message=True,
|
||||
add_generation_prompt=False,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
else:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -413,11 +396,6 @@ def make_kv_cache(
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
@@ -162,19 +153,11 @@ def main(
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
@@ -186,8 +169,6 @@ def main(
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
@@ -226,43 +207,6 @@ def main(
|
||||
break
|
||||
|
||||
|
||||
@cache
|
||||
def get_gpt_oss_encoding():
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return encoding
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
for response in responses:
|
||||
stream.process(response.token)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
break
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import http.client
|
||||
import time
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
@@ -7,8 +6,6 @@ from loguru import logger
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
BAD_STATUSLINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
@@ -18,9 +15,8 @@ async def check_reachability(
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
|
||||
# TODO: use an async http client
|
||||
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
try:
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
@@ -36,16 +32,7 @@ async def check_reachability(
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.BadStatusLine:
|
||||
if attempt >= BAD_STATUSLINE_ATTEMPTS:
|
||||
logger.warning(
|
||||
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
|
||||
)
|
||||
return None
|
||||
time.sleep(1)
|
||||
return _fetch_remote_node_id(attempt=attempt + 1)
|
||||
except http.client.HTTPException as e:
|
||||
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
|
||||
except http.client.HTTPException:
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
Reference in New Issue
Block a user