Compare commits

...

6 Commits

Author SHA1 Message Date
ciaranbor
55aed9b2e9 Reflect quantization level in conversation history 2026-01-24 15:40:05 +00:00
ciaranbor
a39742229c Use quantization param in FE to distinguish quantization variants for image models 2026-01-24 15:32:04 +00:00
ciaranbor
7902b4f24a Update API and placement to handle quantization param 2026-01-24 15:32:04 +00:00
ciaranbor
57e2734f99 Update DistributedImageModel to accept quantization parameter 2026-01-24 15:32:04 +00:00
ciaranbor
cd7d957d9f Generate quantized variable model cards for image models 2026-01-24 15:32:04 +00:00
Evan Quiney
d93db3d6bf re enable the evil network script (#1277)
seems like we still need the interfaces to be routable for mdns. at
least we're not dependent on this behaviour anymore.
2026-01-24 13:36:06 +00:00
11 changed files with 409 additions and 95 deletions

View File

@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -112,18 +112,23 @@
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{ id: string; label: string; isImageModel: boolean }> =
[];
const models: Array<{
id: string;
quantization: number | null;
label: string;
isImageModel: boolean;
}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
const { modelId, quantization } = getInstanceModelInfo(instance);
if (
modelId &&
modelId !== "Unknown" &&
!models.some((m) => m.id === modelId)
!models.some((m) => m.id === modelId && m.quantization === quantization)
) {
models.push({
id: modelId,
label: modelId.split("/").pop() || modelId,
quantization,
label: `${modelId.split("/").pop() || modelId}${quantization ? ` (${quantization}-bit)` : ""}`,
isImageModel: modelSupportsImageGeneration(modelId),
});
}
@@ -145,20 +150,20 @@
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
setSelectedChatModel(models[0].id, models[0].quantization);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some((m) => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
setSelectedChatModel(models[0].id, models[0].quantization);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
setSelectedChatModel(newModels[0].id, newModels[0].quantization);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel("");
setSelectedChatModel("", null);
}
}
@@ -166,16 +171,23 @@
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {
if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
function getInstanceModelInfo(instanceWrapped: unknown): {
modelId: string;
quantization: number | null;
} {
if (!instanceWrapped || typeof instanceWrapped !== "object")
return { modelId: "", quantization: null };
const keys = Object.keys(instanceWrapped as Record<string, unknown>);
if (keys.length === 1) {
const instance = (instanceWrapped as Record<string, unknown>)[
keys[0]
] as { shardAssignments?: { modelId?: string } };
return instance?.shardAssignments?.modelId || "";
] as { shardAssignments?: { modelId?: string; quantization?: number } };
return {
modelId: instance?.shardAssignments?.modelId || "",
quantization: instance?.shardAssignments?.quantization ?? null,
};
}
return "";
return { modelId: "", quantization: null };
}
async function handleFiles(files: File[]) {
@@ -469,7 +481,7 @@
<button
type="button"
onclick={() => {
setSelectedChatModel(model.id);
setSelectedChatModel(model.id, model.quantization);
isModelDropdownOpen = false;
}}
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===

View File

@@ -142,11 +142,15 @@
return null;
}
function formatModelName(modelId: string | null | undefined): string {
function formatModelName(
modelId: string | null | undefined,
quantization: number | null | undefined,
): string {
if (!modelId) return "Unknown Model";
const parts = modelId.split("/");
const tail = parts[parts.length - 1] || modelId;
return tail || modelId;
const baseName = tail || modelId;
return quantization ? `${baseName} (${quantization}-bit)` : baseName;
}
function formatStrategy(
@@ -244,7 +248,7 @@
conversation.instanceType ?? instanceDetails.instanceType;
return {
modelLabel: formatModelName(displayModel),
modelLabel: formatModelName(displayModel, conversation.quantization),
strategyLabel: formatStrategy(sharding, instanceType),
};
}

View File

@@ -162,6 +162,7 @@ export interface ModelDownloadStatus {
// Placement preview from the API
export interface PlacementPreview {
model_id: string;
quantization: number | null; // quantization bits or null for base model
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance: unknown | null;
@@ -227,6 +228,7 @@ export interface Conversation {
createdAt: number;
updatedAt: number;
modelId: string | null;
quantization: number | null;
sharding: string | null;
instanceType: string | null;
}
@@ -491,6 +493,7 @@ class AppStore {
createdAt: conversation.createdAt ?? Date.now(),
updatedAt: conversation.updatedAt ?? Date.now(),
modelId: conversation.modelId ?? null,
quantization: conversation.quantization ?? null,
sharding: conversation.sharding ?? null,
instanceType: conversation.instanceType ?? null,
}));
@@ -669,6 +672,7 @@ class AppStore {
createdAt: now,
updatedAt: now,
modelId: derivedModelId,
quantization: this.selectedChatModelQuantization,
sharding: derivedSharding,
instanceType: derivedInstanceType,
};
@@ -1476,6 +1480,7 @@ class AppStore {
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
quantization: this.selectedChatModelQuantization,
messages: apiMessages,
stream: true,
}),
@@ -1561,11 +1566,17 @@ class AppStore {
*/
selectedChatModel = $state("");
/**
* Selected model's quantization (null for base/unquantized models)
*/
selectedChatModelQuantization = $state<number | null>(null);
/**
* Set the model to use for chat
*/
setSelectedModel(modelId: string) {
setSelectedModel(modelId: string, quantization: number | null = null) {
this.selectedChatModel = modelId;
this.selectedChatModelQuantization = quantization;
// Clear stats when model changes
this.ttftMs = null;
this.tps = null;
@@ -1876,6 +1887,7 @@ class AppStore {
},
body: JSON.stringify({
model: modelToUse,
quantization: this.selectedChatModelQuantization,
messages: apiMessages,
temperature: 0.7,
stream: true,
@@ -2058,6 +2070,7 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
quantization: this.selectedChatModelQuantization,
prompt,
n: params.numImages,
quality: params.quality,
@@ -2284,6 +2297,12 @@ class AppStore {
// Build FormData request
const formData = new FormData();
formData.append("model", model);
if (this.selectedChatModelQuantization !== null) {
formData.append(
"quantization",
this.selectedChatModelQuantization.toString(),
);
}
formData.append("prompt", prompt);
formData.append("image", imageBlob, "image.png");
@@ -2513,6 +2532,8 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const selectedChatModelQuantization = () =>
appStore.selectedChatModelQuantization;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
@@ -2541,8 +2562,10 @@ export const setEditingImage = (imageDataUrl: string, sourceMessage: Message) =>
appStore.setEditingImage(imageDataUrl, sourceMessage);
export const clearEditingImage = () => appStore.clearEditingImage();
export const clearChat = () => appStore.clearChat();
export const setSelectedChatModel = (modelId: string) =>
appStore.setSelectedModel(modelId);
export const setSelectedChatModel = (
modelId: string,
quantization: number | null = null,
) => appStore.setSelectedModel(modelId, quantization);
export const selectPreviewModel = (modelId: string | null) =>
appStore.selectPreviewModel(modelId);
export const togglePreviewNodeFilter = (nodeId: string) =>

View File

@@ -96,6 +96,7 @@
let models = $state<
Array<{
id: string;
quantization?: number | null;
name?: string;
storage_size_megabytes?: number;
tasks?: string[];
@@ -103,12 +104,38 @@
}>
>([]);
// Helper to get unique model key (combines id + quantization)
function getModelKey(model: {
id: string;
quantization?: number | null;
}): string {
return model.quantization != null
? `${model.id}-q${model.quantization}`
: model.id;
}
// Helper to get display name with quantization suffix
function getModelDisplayName(model: {
id: string;
name?: string;
quantization?: number | null;
}): string {
const baseName = model.name || model.id;
if (model.quantization != null) {
return `${baseName} (${model.quantization}-bit)`;
}
return baseName;
}
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
// Map by unique key (model_id + quantization)
const key = getModelKey(model);
tasks[key] = model.tasks;
// Also map by short ID (for backwards compatibility)
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
@@ -146,6 +173,7 @@
const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
interface LaunchDefaults {
modelId: string | null;
quantization: number | null;
sharding: "Pipeline" | "Tensor";
instanceType: InstanceMeta;
minNodes: number;
@@ -154,6 +182,7 @@
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
quantization: selectedQuantization,
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
@@ -177,7 +206,7 @@
}
function applyLaunchDefaults(
availableModels: Array<{ id: string }>,
availableModels: Array<{ id: string; quantization?: number | null }>,
maxNodes: number,
): void {
const defaults = loadLaunchDefaults();
@@ -196,12 +225,17 @@
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (
defaults.modelId &&
availableModels.some((m) => m.id === defaults.modelId)
) {
selectPreviewModel(defaults.modelId);
// Only apply model if it exists in the available models (matching both id and quantization)
if (defaults.modelId) {
const matchingModel = availableModels.find(
(m) =>
m.id === defaults.modelId &&
(m.quantization ?? null) === (defaults.quantization ?? null),
);
if (matchingModel) {
selectPreviewModel(defaults.modelId);
selectedQuantization = defaults.quantization ?? null;
}
}
}
@@ -209,6 +243,7 @@
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let selectedQuantization = $state<number | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
@@ -469,39 +504,40 @@
if (models.length === 0) return tags;
// Find the fastest model (highest TPS)
let fastestId = "";
let fastestKey = "";
let highestTps = 0;
// Find the biggest model (most memory)
let biggestId = "";
let biggestKey = "";
let highestMemory = 0;
for (const model of models) {
const key = getModelKey(model);
const perf = estimatePerformance(model.id);
const mem = getModelSizeGB(model);
if (perf.tps > highestTps) {
highestTps = perf.tps;
fastestId = model.id;
fastestKey = key;
}
if (mem > highestMemory) {
highestMemory = mem;
biggestId = model.id;
biggestKey = key;
}
}
if (fastestId) {
tags[fastestId] = tags[fastestId] || [];
tags[fastestId].push("FASTEST");
if (fastestKey) {
tags[fastestKey] = tags[fastestKey] || [];
tags[fastestKey].push("FASTEST");
}
if (biggestId && biggestId !== fastestId) {
tags[biggestId] = tags[biggestId] || [];
tags[biggestId].push("BIGGEST");
} else if (biggestId === fastestId && biggestId) {
if (biggestKey && biggestKey !== fastestKey) {
tags[biggestKey] = tags[biggestKey] || [];
tags[biggestKey].push("BIGGEST");
} else if (biggestKey === fastestKey && biggestKey) {
// Same model is both - unlikely but handle it
tags[biggestId].push("BIGGEST");
tags[biggestKey].push("BIGGEST");
}
return tags;
@@ -531,12 +567,13 @@
}
async function launchInstance(
modelId: string,
model: { id: string; quantization?: number | null },
specificPreview?: PlacementPreview | null,
) {
if (!modelId || launchingModelId) return;
const modelKey = getModelKey(model);
if (!model.id || launchingModelId) return;
launchingModelId = modelId;
launchingModelId = modelKey;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
@@ -550,7 +587,7 @@
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
`/instance/placement?model_id=${encodeURIComponent(model.id)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
);
if (!placementResponse.ok) {
@@ -574,7 +611,7 @@
console.error("Failed to launch instance:", errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
setSelectedChatModel(model.id, model.quantization ?? null);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -1074,19 +1111,20 @@
const [, lastInstance] =
remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
const newQuantization = getInstanceQuantization(lastInstance);
if (
newModelId &&
newModelId !== "Unknown" &&
newModelId !== "Unknown Model"
) {
setSelectedChatModel(newModelId);
setSelectedChatModel(newModelId, newQuantization);
} else {
// Clear selection if no valid model found
setSelectedChatModel("");
setSelectedChatModel("", null);
}
} else {
// No more instances, clear the selection
setSelectedChatModel("");
setSelectedChatModel("", null);
}
}
} catch (error) {
@@ -1112,6 +1150,16 @@
return inst.shardAssignments?.modelId || "Unknown Model";
}
// Get quantization from an instance
function getInstanceQuantization(instanceWrapped: unknown): number | null {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== "object") return null;
const inst = instance as {
shardAssignments?: { quantization?: number | null };
};
return inst.shardAssignments?.quantization ?? null;
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
@@ -1533,15 +1581,16 @@
// Get ALL filtered previews based on current settings (matching minimum nodes)
// Note: previewsData already contains previews for the selected model (fetched via API)
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
// Backend handles node_ids filtering, we filter by sharding/instance type, quantization, and min nodes
const filteredPreviews = $derived(() => {
if (!selectedModelId || previewsData.length === 0) return [];
// Find previews matching sharding/instance type (model_id filter not needed since previewsData is already for selected model)
// Find previews matching sharding/instance type and quantization
const matchingPreviews = previewsData.filter(
(p: PlacementPreview) =>
p.sharding === selectedSharding &&
matchesSelectedRuntime(p.instance_meta) &&
p.quantization === selectedQuantization &&
p.error === null &&
p.memory_delta_by_node !== null,
);
@@ -1948,6 +1997,8 @@
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceQuantization =
getInstanceQuantization(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
@@ -1963,7 +2014,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}}
onkeydown={(e) => {
@@ -1973,7 +2027,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}
}}
@@ -2064,7 +2121,9 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}
{getInstanceModelId(instance)}{instanceQuantization
? ` (${instanceQuantization}-bit)`
: ""}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"
@@ -2371,7 +2430,9 @@
>
{#if selectedModelId}
{@const foundModel = models.find(
(m) => m.id === selectedModelId,
(m) =>
m.id === selectedModelId &&
(m.quantization ?? null) === selectedQuantization,
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
@@ -2424,7 +2485,7 @@
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>{getModelDisplayName(foundModel)}</span
>
</span>
<span class="text-white/50 text-xs flex-shrink-0"
@@ -2503,6 +2564,7 @@
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
selectedQuantization = model.quantization ?? null;
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
@@ -2510,7 +2572,8 @@
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id
model.id &&
selectedQuantization === (model.quantization ?? null)
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
@@ -2555,7 +2618,9 @@
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
<span class="truncate"
>{getModelDisplayName(model)}</span
>
</span>
<span
class="flex-shrink-0 text-xs {modelCanFit
@@ -2765,14 +2830,16 @@
</div>
{:else}
{@const selectedModel = models.find(
(m) => m.id === selectedModelId,
(m) =>
m.id === selectedModelId &&
(m.quantization ?? null) === selectedQuantization,
)}
{@const allPreviews = filteredPreviews()}
{#if selectedModel && allPreviews.length > 0}
{@const downloadStatus = getModelDownloadStatus(
selectedModel.id,
)}
{@const tags = modelTags()[selectedModel.id] || []}
{@const tags = modelTags()[getModelKey(selectedModel)] || []}
<div class="space-y-3">
{#each allPreviews as apiPreview, i}
<div
@@ -2790,13 +2857,14 @@
>
<ModelCard
model={selectedModel}
isLaunching={launchingModelId === selectedModel.id}
isLaunching={launchingModelId ===
getModelKey(selectedModel)}
{downloadStatus}
nodes={data?.nodes ?? {}}
sharding={apiPreview.sharding}
runtime={apiPreview.instance_meta}
onLaunch={() =>
launchInstance(selectedModel.id, apiPreview)}
launchInstance(selectedModel, apiPreview)}
{tags}
{apiPreview}
modelIdOverride={apiPreview.model_id}
@@ -2945,6 +3013,8 @@
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceQuantization =
getInstanceQuantization(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
@@ -2960,7 +3030,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}}
onkeydown={(e) => {
@@ -2970,7 +3043,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}
}}
@@ -3061,7 +3137,9 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}
{getInstanceModelId(instance)}{instanceQuantization
? ` (${instanceQuantization}-bit)`
: ""}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"

View File

@@ -141,14 +141,19 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
async def resolve_model_card(
model_id: ModelId, quantization: int | None = None
) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
if quantization is None and card.quantization is None:
return card
if card.quantization == quantization:
return card
return await ModelCard.from_hf(model_id)
@@ -354,7 +359,7 @@ class API:
model_id: ModelId,
node_ids: Annotated[list[NodeId] | None, Query()] = None,
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
seen: set[tuple[ModelId, int | None, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None
@@ -396,17 +401,32 @@ class API:
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
)
)
continue
current_ids = set(self.state.instances.keys())
@@ -417,17 +437,32 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
)
)
continue
instance = new_instances[0]
@@ -447,6 +482,7 @@ class API:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
len(placement_node_ids),
@@ -454,6 +490,7 @@ class API:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -464,6 +501,7 @@ class API:
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
len(placement_node_ids),
@@ -669,6 +707,18 @@ class API:
"TODO: we should send a notification to the user to download the model"
)
def _has_matching_instance(
self, model_id: ModelId, quantization: int | None
) -> bool:
"""Check if there's a running instance matching the model_id and quantization."""
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id == model_id
and instance.shard_assignments.quantization == quantization
):
return True
return False
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
@@ -720,22 +770,23 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
async def _validate_image_model(
self, model: str, quantization: int | None = None
) -> tuple[ModelId, int | None]:
"""Validate model exists and return resolved model ID and quantization.
Raises HTTPException 404 if no instance is found for the model.
Returns tuple of (model_id, quantization).
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await resolve_model_card(ModelId(model), quantization)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
resolved_quant = model_card.quantization
if not self._has_matching_instance(ModelId(resolved_model), resolved_quant):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model
return resolved_model, resolved_quant
async def get_image(self, image_id: str) -> FileResponse:
stored = self._image_store.get(Id(image_id))
@@ -771,7 +822,9 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload.model, payload.quantization = await self._validate_image_model(
payload.model, payload.quantization
)
command = ImageGeneration(
request_params=payload,
@@ -1016,7 +1069,9 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model, payload.quantization = await self._validate_image_model(
payload.model, payload.quantization
)
payload.stream = False
payload.partial_images = 0
@@ -1038,6 +1093,7 @@ class API:
image: UploadFile,
prompt: str,
model: str,
quantization: int | None,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1050,7 +1106,9 @@ class API:
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
resolved_model, resolved_quant = await self._validate_image_model(
model, quantization
)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1069,6 +1127,7 @@ class API:
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
quantization=resolved_quant,
n=n,
size=size,
response_format=response_format,
@@ -1107,6 +1166,7 @@ class API:
image: UploadFile = File(...), # noqa: B008
prompt: str = Form(...),
model: str = Form(...),
quantization: str | None = Form(None),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1121,6 +1181,9 @@ class API:
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
quantization_int = (
int(quantization) if quantization and quantization.isdigit() else None
)
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
@@ -1133,6 +1196,7 @@ class API:
image=image,
prompt=prompt,
model=model,
quantization=quantization_int,
n=n,
size=size,
response_format=response_format,
@@ -1169,6 +1233,7 @@ class API:
image: UploadFile = File(...), # noqa: B008
prompt: str = Form(...),
model: str = Form(...),
quantization: str | None = Form(None),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1178,6 +1243,10 @@ class API:
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
quantization_int = (
int(quantization) if quantization and quantization.isdigit() else None
)
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
@@ -1189,6 +1258,7 @@ class API:
image=image,
prompt=prompt,
model=model,
quantization=quantization_int,
n=n,
size=size,
response_format=response_format,
@@ -1223,6 +1293,7 @@ class API:
data=[
ModelListModel(
id=card.model_id,
quantization=card.quantization,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",

View File

@@ -137,6 +137,7 @@ def get_shard_assignments_for_pipeline_parallel(
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
quantization=model_card.quantization,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -170,6 +171,7 @@ def get_shard_assignments_for_tensor_parallel(
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
quantization=model_card.quantization,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)

View File

@@ -40,6 +40,7 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
quantization: int | None = None
@field_validator("tasks", mode="before")
@classmethod
@@ -413,7 +414,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -428,7 +429,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -442,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -470,7 +471,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -484,7 +485,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -543,7 +544,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -551,10 +552,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -577,7 +578,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -585,10 +586,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -610,6 +611,91 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _create_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
quantizations = [8, 6, 5, 4, 3]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
quantization=None,
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=base_card.model_id,
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
quantization=quant,
)
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)

View File

@@ -31,6 +31,7 @@ class ErrorResponse(BaseModel):
class ModelListModel(BaseModel):
id: str
quantization: int | None = None
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "exo"
@@ -216,6 +217,7 @@ class CreateInstanceParams(BaseModel):
class PlacementPreview(BaseModel):
model_id: ModelId
quantization: int | None = None
sharding: Sharding
instance_meta: InstanceMeta
instance: Instance | None = None
@@ -267,6 +269,7 @@ class ImageGenerationTaskParams(BaseModel):
style: str | None = "vivid"
user: str | None = None
advanced_params: AdvancedImageParams | None = None
quantization: int | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@@ -292,6 +295,7 @@ class ImageEditsTaskParams(BaseModel):
stream: bool | None = False
user: str | None = None
advanced_params: AdvancedImageParams | None = None
quantization: int | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@@ -303,6 +307,7 @@ class ImageEditsInternalParams(BaseModel):
total_input_chunks: int = 0
prompt: str
model: str
quantization: int | None = None
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"

View File

@@ -82,6 +82,7 @@ RunnerStatus = (
class ShardAssignments(CamelCaseModel):
model_id: ModelId
quantization: int | None = None
runner_to_shard: Mapping[RunnerId, ShardMetadata]
node_to_runner: Mapping[NodeId, RunnerId]

View File

@@ -71,8 +71,10 @@ class DistributedImageModel:
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_card.model_id
model_card = bound_instance.bound_shard.model_card
model_id = model_card.model_id
model_path = build_model_path(model_id)
quantize = model_card.quantization
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
@@ -93,6 +95,7 @@ class DistributedImageModel:
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
quantize=quantize,
)
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int: