Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-24 13:59:48 -05:00)

Compare commits: runner-can ... ciaran/ima (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 55aed9b2e9 | |
| | a39742229c | |
| | 7902b4f24a | |
| | 57e2734f99 | |
| | cd7d957d9f | |
| | d93db3d6bf | |
```diff
@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
     # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
     /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true

+    networksetup -listlocations | grep -q exo || {
+        networksetup -createlocation exo
+    }
+
+    networksetup -switchtolocation exo
+    networksetup -listallhardwareports \\
+        | awk -F': ' '/Hardware Port: / {print $2}' \\
+        | while IFS=":" read -r name; do
+            case "$name" in
+                "Ethernet Adapter"*)
+                    ;;
+                "Thunderbolt Bridge")
+                    ;;
+                "Thunderbolt "*)
+                    networksetup -listallnetworkservices \\
+                        | grep -q "EXO $name" \\
+                        || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
+                        || continue
+                    networksetup -setdhcp "EXO $name"
+                    ;;
+                *)
+                    networksetup -listallnetworkservices \\
+                        | grep -q "$name" \\
+                        || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
+                        || continue
+                    ;;
+            esac
+        done
+
     networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
         networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
     } || true
```
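The setup is deliberately idempotent: every mutating `networksetup` call is guarded by a `grep -q` existence check or a `|| true`, so re-running the helper is safe. The same guard pattern, sketched from Python (the `networksetup` subcommands are real macOS commands; the wrapper function itself is illustrative, not part of this PR):

```python
import subprocess

def ensure_network_location(name: str) -> None:
    """Create a macOS network location only if it does not already exist.

    Mirrors the `networksetup -listlocations | grep -q exo || ...` guard
    in the hunk above. Illustrative sketch, not code from this PR.
    """
    locations = subprocess.run(
        ["networksetup", "-listlocations"],
        capture_output=True, text=True, check=True,
    ).stdout.splitlines()
    if name not in locations:
        subprocess.run(["networksetup", "-createlocation", name], check=True)

if __name__ == "__main__":
    ensure_network_location("exo")
```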
```diff
@@ -112,18 +112,23 @@

   // Extract available models from running instances
   const availableModels = $derived(() => {
-    const models: Array<{ id: string; label: string; isImageModel: boolean }> =
-      [];
+    const models: Array<{
+      id: string;
+      quantization: number | null;
+      label: string;
+      isImageModel: boolean;
+    }> = [];
     for (const [, instance] of Object.entries(instanceData)) {
-      const modelId = getInstanceModelId(instance);
+      const { modelId, quantization } = getInstanceModelInfo(instance);
       if (
         modelId &&
         modelId !== "Unknown" &&
-        !models.some((m) => m.id === modelId)
+        !models.some((m) => m.id === modelId && m.quantization === quantization)
       ) {
         models.push({
           id: modelId,
+          quantization,
-          label: modelId.split("/").pop() || modelId,
+          label: `${modelId.split("/").pop() || modelId}${quantization ? ` (${quantization}-bit)` : ""}`,
           isImageModel: modelSupportsImageGeneration(modelId),
         });
       }
```
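The dedup key changes from the bare model id to the (id, quantization) pair, so the same checkpoint running at different bit widths now shows up as separate chat targets. The same idea in minimal standalone Python (data made up for illustration):

```python
# Deduplicate by (model_id, quantization) instead of model_id alone,
# so a model at 4-bit and at 8-bit are two distinct entries.
instances = [
    {"model_id": "Qwen/Qwen-Image", "quantization": 4},
    {"model_id": "Qwen/Qwen-Image", "quantization": 8},
    {"model_id": "Qwen/Qwen-Image", "quantization": 4},  # duplicate
]

seen: set[tuple[str, int | None]] = set()
models = []
for inst in instances:
    key = (inst["model_id"], inst["quantization"])
    if key not in seen:
        seen.add(key)
        models.append(inst)

assert len(models) == 2  # one entry per (id, quantization) pair
```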
```diff
@@ -145,20 +150,20 @@

       // If no model selected, select the first available
       if (!currentModel) {
-        setSelectedChatModel(models[0].id);
+        setSelectedChatModel(models[0].id, models[0].quantization);
       }
       // If current model is stale (no longer has a running instance), reset to first available
       else if (!models.some((m) => m.id === currentModel)) {
-        setSelectedChatModel(models[0].id);
+        setSelectedChatModel(models[0].id, models[0].quantization);
       }
       // If a new model was just added, select it
       else if (newModels.length > 0 && previousModelIds.size > 0) {
-        setSelectedChatModel(newModels[0].id);
+        setSelectedChatModel(newModels[0].id, newModels[0].quantization);
       }
     } else {
       // No instances running - clear the selected model
       if (currentModel) {
-        setSelectedChatModel("");
+        setSelectedChatModel("", null);
       }
     }

@@ -166,16 +171,23 @@
     previousModelIds = currentModelIds;
   });

-  function getInstanceModelId(instanceWrapped: unknown): string {
-    if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
+  function getInstanceModelInfo(instanceWrapped: unknown): {
+    modelId: string;
+    quantization: number | null;
+  } {
+    if (!instanceWrapped || typeof instanceWrapped !== "object")
+      return { modelId: "", quantization: null };
     const keys = Object.keys(instanceWrapped as Record<string, unknown>);
     if (keys.length === 1) {
       const instance = (instanceWrapped as Record<string, unknown>)[
         keys[0]
-      ] as { shardAssignments?: { modelId?: string } };
-      return instance?.shardAssignments?.modelId || "";
+      ] as { shardAssignments?: { modelId?: string; quantization?: number } };
+      return {
+        modelId: instance?.shardAssignments?.modelId || "",
+        quantization: instance?.shardAssignments?.quantization ?? null,
+      };
     }
-    return "";
+    return { modelId: "", quantization: null };
  }

   async function handleFiles(files: File[]) {
```
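Instances arrive as a single-key tagged wrapper (the key is the variant tag, the value carries `shardAssignments`), which is why the helper checks `keys.length === 1` before unwrapping. A hedged Python equivalent of the unwrap (field names taken from the hunk; the exact wrapper shape is inferred):

```python
def get_instance_model_info(wrapped: object) -> tuple[str, int | None]:
    """Unwrap {tag: {"shardAssignments": {...}}} into (model_id, quantization)."""
    if not isinstance(wrapped, dict) or len(wrapped) != 1:
        return "", None
    (inner,) = wrapped.values()  # the single tagged value
    if not isinstance(inner, dict):
        return "", None
    assignments = inner.get("shardAssignments") or {}
    return assignments.get("modelId") or "", assignments.get("quantization")

# Example shape, inferred from the diff above:
wrapped = {"MlxRing": {"shardAssignments": {"modelId": "Qwen/Qwen-Image", "quantization": 4}}}
assert get_instance_model_info(wrapped) == ("Qwen/Qwen-Image", 4)
```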
```diff
@@ -469,7 +481,7 @@
               <button
                 type="button"
                 onclick={() => {
-                  setSelectedChatModel(model.id);
+                  setSelectedChatModel(model.id, model.quantization);
                   isModelDropdownOpen = false;
                 }}
                 class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===
```
```diff
@@ -142,11 +142,15 @@
   return null;
 }

-function formatModelName(modelId: string | null | undefined): string {
+function formatModelName(
+  modelId: string | null | undefined,
+  quantization: number | null | undefined,
+): string {
   if (!modelId) return "Unknown Model";
   const parts = modelId.split("/");
   const tail = parts[parts.length - 1] || modelId;
-  return tail || modelId;
+  const baseName = tail || modelId;
+  return quantization ? `${baseName} (${quantization}-bit)` : baseName;
 }

 function formatStrategy(
```
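A worked example of the new label logic, ported to Python for illustration (the function name and behavior come from the hunk above):

```python
def format_model_name(model_id: str | None, quantization: int | None) -> str:
    # Mirrors the TypeScript: take the path tail, append "(N-bit)" when quantized.
    if not model_id:
        return "Unknown Model"
    base = model_id.split("/")[-1] or model_id
    return f"{base} ({quantization}-bit)" if quantization else base

assert format_model_name("black-forest-labs/FLUX.1-schnell", None) == "FLUX.1-schnell"
assert format_model_name("black-forest-labs/FLUX.1-schnell", 4) == "FLUX.1-schnell (4-bit)"
```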
```diff
@@ -244,7 +248,7 @@
     conversation.instanceType ?? instanceDetails.instanceType;

   return {
-    modelLabel: formatModelName(displayModel),
+    modelLabel: formatModelName(displayModel, conversation.quantization),
     strategyLabel: formatStrategy(sharding, instanceType),
   };
 }
```
```diff
@@ -162,6 +162,7 @@ export interface ModelDownloadStatus {
 // Placement preview from the API
 export interface PlacementPreview {
   model_id: string;
+  quantization: number | null; // quantization bits or null for base model
   sharding: "Pipeline" | "Tensor";
   instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
   instance: unknown | null;
@@ -227,6 +228,7 @@ export interface Conversation {
   createdAt: number;
   updatedAt: number;
   modelId: string | null;
+  quantization: number | null;
   sharding: string | null;
   instanceType: string | null;
 }
```
```diff
@@ -491,6 +493,7 @@ class AppStore {
       createdAt: conversation.createdAt ?? Date.now(),
       updatedAt: conversation.updatedAt ?? Date.now(),
       modelId: conversation.modelId ?? null,
+      quantization: conversation.quantization ?? null,
       sharding: conversation.sharding ?? null,
       instanceType: conversation.instanceType ?? null,
     }));
@@ -669,6 +672,7 @@ class AppStore {
       createdAt: now,
       updatedAt: now,
       modelId: derivedModelId,
+      quantization: this.selectedChatModelQuantization,
       sharding: derivedSharding,
       instanceType: derivedInstanceType,
     };
@@ -1476,6 +1480,7 @@ class AppStore {
       headers: { "Content-Type": "application/json" },
       body: JSON.stringify({
         model: modelToUse,
+        quantization: this.selectedChatModelQuantization,
         messages: apiMessages,
         stream: true,
       }),
```
```diff
@@ -1561,11 +1566,17 @@ class AppStore {
   */
  selectedChatModel = $state("");

+  /**
+   * Selected model's quantization (null for base/unquantized models)
+   */
+  selectedChatModelQuantization = $state<number | null>(null);
+
  /**
   * Set the model to use for chat
   */
-  setSelectedModel(modelId: string) {
+  setSelectedModel(modelId: string, quantization: number | null = null) {
    this.selectedChatModel = modelId;
+    this.selectedChatModelQuantization = quantization;
    // Clear stats when model changes
    this.ttftMs = null;
    this.tps = null;
```
```diff
@@ -1876,6 +1887,7 @@ class AppStore {
       },
       body: JSON.stringify({
         model: modelToUse,
+        quantization: this.selectedChatModelQuantization,
         messages: apiMessages,
         temperature: 0.7,
         stream: true,
@@ -2058,6 +2070,7 @@ class AppStore {

     const requestBody: Record<string, unknown> = {
       model,
+      quantization: this.selectedChatModelQuantization,
       prompt,
       n: params.numImages,
       quality: params.quality,
```
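With these hunks every chat and image request carries the selected quantization alongside the model id. Roughly the JSON the server now sees (field names from the diff; values illustrative):

```python
import json

# Sketch of the request body built above; "model" and "quantization"
# travel together so the API can route to the matching instance.
body = {
    "model": "Qwen/Qwen-Image",
    "quantization": 4,          # or None for the base bf16 model
    "prompt": "a lighthouse at dusk",
    "n": 1,
    "quality": "medium",
}
print(json.dumps(body))
```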
```diff
@@ -2284,6 +2297,12 @@ class AppStore {
     // Build FormData request
     const formData = new FormData();
     formData.append("model", model);
+    if (this.selectedChatModelQuantization !== null) {
+      formData.append(
+        "quantization",
+        this.selectedChatModelQuantization.toString(),
+      );
+    }
     formData.append("prompt", prompt);
     formData.append("image", imageBlob, "image.png");
@@ -2513,6 +2532,8 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
 export const lastUpdate = () => appStore.lastUpdate;
 export const isTopologyMinimized = () => appStore.isTopologyMinimized;
 export const selectedChatModel = () => appStore.selectedChatModel;
+export const selectedChatModelQuantization = () =>
+  appStore.selectedChatModelQuantization;
 export const debugMode = () => appStore.getDebugMode();
 export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
 export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
@@ -2541,8 +2562,10 @@ export const setEditingImage = (imageDataUrl: string, sourceMessage: Message) =>
   appStore.setEditingImage(imageDataUrl, sourceMessage);
 export const clearEditingImage = () => appStore.clearEditingImage();
 export const clearChat = () => appStore.clearChat();
-export const setSelectedChatModel = (modelId: string) =>
-  appStore.setSelectedModel(modelId);
+export const setSelectedChatModel = (
+  modelId: string,
+  quantization: number | null = null,
+) => appStore.setSelectedModel(modelId, quantization);
 export const selectPreviewModel = (modelId: string | null) =>
   appStore.selectPreviewModel(modelId);
 export const togglePreviewNodeFilter = (nodeId: string) =>
```
```diff
@@ -96,6 +96,7 @@
   let models = $state<
     Array<{
       id: string;
+      quantization?: number | null;
       name?: string;
       storage_size_megabytes?: number;
       tasks?: string[];
@@ -103,12 +104,38 @@
     }>
   >([]);

+  // Helper to get unique model key (combines id + quantization)
+  function getModelKey(model: {
+    id: string;
+    quantization?: number | null;
+  }): string {
+    return model.quantization != null
+      ? `${model.id}-q${model.quantization}`
+      : model.id;
+  }
+
+  // Helper to get display name with quantization suffix
+  function getModelDisplayName(model: {
+    id: string;
+    name?: string;
+    quantization?: number | null;
+  }): string {
+    const baseName = model.name || model.id;
+    if (model.quantization != null) {
+      return `${baseName} (${model.quantization}-bit)`;
+    }
+    return baseName;
+  }
+
   // Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
   const modelTasks = $derived(() => {
     const tasks: Record<string, string[]> = {};
     for (const model of models) {
       if (model.tasks && model.tasks.length > 0) {
-        // Map by short ID
+        // Map by unique key (model_id + quantization)
+        const key = getModelKey(model);
+        tasks[key] = model.tasks;
+        // Also map by short ID (for backwards compatibility)
         tasks[model.id] = model.tasks;
         // Also map by hugging_face_id from the API response
         if (model.hugging_face_id) {
```
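Keys like `flux1-schnell-q4` keep quantized variants distinct in lookup tables, while the plain id remains valid for unquantized models. A worked example of the key scheme (Python, illustrative):

```python
def get_model_key(model_id: str, quantization: int | None) -> str:
    # Same scheme as getModelKey above: suffix "-q<bits>" only when quantized.
    return f"{model_id}-q{quantization}" if quantization is not None else model_id

assert get_model_key("flux1-schnell", None) == "flux1-schnell"
assert get_model_key("flux1-schnell", 4) == "flux1-schnell-q4"
# Distinct keys mean distinct entries in the tasks/tags dictionaries:
assert get_model_key("flux1-schnell", 4) != get_model_key("flux1-schnell", 8)
```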
```diff
@@ -146,6 +173,7 @@
   const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
   interface LaunchDefaults {
     modelId: string | null;
+    quantization: number | null;
     sharding: "Pipeline" | "Tensor";
     instanceType: InstanceMeta;
     minNodes: number;
@@ -154,6 +182,7 @@
   function saveLaunchDefaults(): void {
     const defaults: LaunchDefaults = {
       modelId: selectedPreviewModelId(),
+      quantization: selectedQuantization,
       sharding: selectedSharding,
       instanceType: selectedInstanceType,
       minNodes: selectedMinNodes,
@@ -177,7 +206,7 @@
   }

   function applyLaunchDefaults(
-    availableModels: Array<{ id: string }>,
+    availableModels: Array<{ id: string; quantization?: number | null }>,
     maxNodes: number,
   ): void {
     const defaults = loadLaunchDefaults();
```
```diff
@@ -196,12 +225,17 @@
       selectedMinNodes = defaults.minNodes;
     }

-    // Only apply model if it exists in the available models
-    if (
-      defaults.modelId &&
-      availableModels.some((m) => m.id === defaults.modelId)
-    ) {
-      selectPreviewModel(defaults.modelId);
+    // Only apply model if it exists in the available models (matching both id and quantization)
+    if (defaults.modelId) {
+      const matchingModel = availableModels.find(
+        (m) =>
+          m.id === defaults.modelId &&
+          (m.quantization ?? null) === (defaults.quantization ?? null),
+      );
+      if (matchingModel) {
+        selectPreviewModel(defaults.modelId);
+        selectedQuantization = defaults.quantization ?? null;
+      }
     }
   }
```
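Saved launch defaults are only applied when a model with the same id *and* quantization is still available; a stale default is silently ignored. The same matching rule, standalone (Python, illustrative data):

```python
def apply_defaults(defaults: dict, available: list[dict]) -> dict | None:
    """Return the matching model, or None if the saved default is stale.

    Mirrors applyLaunchDefaults: both id and quantization must match,
    treating a missing quantization as None.
    """
    for m in available:
        if (
            m["id"] == defaults.get("modelId")
            and m.get("quantization") == defaults.get("quantization")
        ):
            return m
    return None

available = [{"id": "flux1-schnell", "quantization": 4}]
assert apply_defaults({"modelId": "flux1-schnell", "quantization": 4}, available)
assert apply_defaults({"modelId": "flux1-schnell", "quantization": 8}, available) is None
```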
```diff
@@ -209,6 +243,7 @@
   let selectedMinNodes = $state<number>(1);
   let minNodesInitialized = $state(false);
   let launchingModelId = $state<string | null>(null);
+  let selectedQuantization = $state<number | null>(null);
   let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());

   // Custom dropdown state
@@ -469,39 +504,40 @@
     if (models.length === 0) return tags;

     // Find the fastest model (highest TPS)
-    let fastestId = "";
+    let fastestKey = "";
     let highestTps = 0;

     // Find the biggest model (most memory)
-    let biggestId = "";
+    let biggestKey = "";
     let highestMemory = 0;

     for (const model of models) {
+      const key = getModelKey(model);
       const perf = estimatePerformance(model.id);
       const mem = getModelSizeGB(model);

       if (perf.tps > highestTps) {
         highestTps = perf.tps;
-        fastestId = model.id;
+        fastestKey = key;
       }

       if (mem > highestMemory) {
         highestMemory = mem;
-        biggestId = model.id;
+        biggestKey = key;
       }
     }

-    if (fastestId) {
-      tags[fastestId] = tags[fastestId] || [];
-      tags[fastestId].push("FASTEST");
+    if (fastestKey) {
+      tags[fastestKey] = tags[fastestKey] || [];
+      tags[fastestKey].push("FASTEST");
     }

-    if (biggestId && biggestId !== fastestId) {
-      tags[biggestId] = tags[biggestId] || [];
-      tags[biggestId].push("BIGGEST");
-    } else if (biggestId === fastestId && biggestId) {
+    if (biggestKey && biggestKey !== fastestKey) {
+      tags[biggestKey] = tags[biggestKey] || [];
+      tags[biggestKey].push("BIGGEST");
+    } else if (biggestKey === fastestKey && biggestKey) {
       // Same model is both - unlikely but handle it
-      tags[biggestId].push("BIGGEST");
+      tags[biggestKey].push("BIGGEST");
     }

     return tags;
```
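Tags are now stored under the composite key, so a 4-bit and an 8-bit variant of the same checkpoint can each earn their own badge. A tiny sketch of the resulting structure (Python, made-up keys):

```python
# Tags dict keyed by model key rather than bare id.
tags: dict[str, list[str]] = {}

fastest_key = "flux1-schnell-q4"  # highest estimated TPS in this example
biggest_key = "flux1-schnell"     # largest memory footprint (base bf16)

tags.setdefault(fastest_key, []).append("FASTEST")
if biggest_key != fastest_key:
    tags.setdefault(biggest_key, []).append("BIGGEST")

assert tags == {"flux1-schnell-q4": ["FASTEST"], "flux1-schnell": ["BIGGEST"]}
```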
```diff
@@ -531,12 +567,13 @@
   }

   async function launchInstance(
-    modelId: string,
+    model: { id: string; quantization?: number | null },
     specificPreview?: PlacementPreview | null,
   ) {
-    if (!modelId || launchingModelId) return;
+    const modelKey = getModelKey(model);
+    if (!model.id || launchingModelId) return;

-    launchingModelId = modelId;
+    launchingModelId = modelKey;

     try {
       // Use the specific preview if provided, otherwise fall back to filtered preview
@@ -550,7 +587,7 @@
       } else {
         // Fallback: GET placement from API
         const placementResponse = await fetch(
-          `/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
+          `/instance/placement?model_id=${encodeURIComponent(model.id)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
         );

         if (!placementResponse.ok) {
@@ -574,7 +611,7 @@
         console.error("Failed to launch instance:", errorText);
       } else {
         // Always auto-select the newly launched model so the user chats to what they just launched
-        setSelectedChatModel(modelId);
+        setSelectedChatModel(model.id, model.quantization ?? null);

         // Scroll to the bottom of instances container to show the new instance
         // Use multiple attempts to ensure DOM has updated with the new instance
@@ -1074,19 +1111,20 @@
           const [, lastInstance] =
             remainingInstances[remainingInstances.length - 1];
           const newModelId = getInstanceModelId(lastInstance);
+          const newQuantization = getInstanceQuantization(lastInstance);
           if (
             newModelId &&
             newModelId !== "Unknown" &&
             newModelId !== "Unknown Model"
           ) {
-            setSelectedChatModel(newModelId);
+            setSelectedChatModel(newModelId, newQuantization);
           } else {
             // Clear selection if no valid model found
-            setSelectedChatModel("");
+            setSelectedChatModel("", null);
           }
         } else {
           // No more instances, clear the selection
-          setSelectedChatModel("");
+          setSelectedChatModel("", null);
         }
       }
     } catch (error) {
```
```diff
@@ -1112,6 +1150,16 @@
     return inst.shardAssignments?.modelId || "Unknown Model";
   }

+  // Get quantization from an instance
+  function getInstanceQuantization(instanceWrapped: unknown): number | null {
+    const [, instance] = getTagged(instanceWrapped);
+    if (!instance || typeof instance !== "object") return null;
+    const inst = instance as {
+      shardAssignments?: { quantization?: number | null };
+    };
+    return inst.shardAssignments?.quantization ?? null;
+  }
+
   // Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
   function getInstanceInfo(instanceWrapped: unknown): {
     instanceType: string;
@@ -1533,15 +1581,16 @@

   // Get ALL filtered previews based on current settings (matching minimum nodes)
   // Note: previewsData already contains previews for the selected model (fetched via API)
-  // Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
+  // Backend handles node_ids filtering, we filter by sharding/instance type, quantization, and min nodes
   const filteredPreviews = $derived(() => {
     if (!selectedModelId || previewsData.length === 0) return [];

-    // Find previews matching sharding/instance type (model_id filter not needed since previewsData is already for selected model)
+    // Find previews matching sharding/instance type and quantization
     const matchingPreviews = previewsData.filter(
       (p: PlacementPreview) =>
         p.sharding === selectedSharding &&
         matchesSelectedRuntime(p.instance_meta) &&
+        p.quantization === selectedQuantization &&
         p.error === null &&
         p.memory_delta_by_node !== null,
     );
```
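The preview filter grows exactly one predicate: `p.quantization === selectedQuantization`. In Python terms (a sketch over plain dicts, not the real data model):

```python
def filter_previews(previews: list[dict], sharding: str, runtime: str,
                    quantization: int | None) -> list[dict]:
    # Previews must match sharding, runtime, and now the selected
    # quantization, and carry no placement error.
    return [
        p for p in previews
        if p["sharding"] == sharding
        and p["instance_meta"] == runtime
        and p["quantization"] == quantization
        and p["error"] is None
        and p["memory_delta_by_node"] is not None
    ]
```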
```diff
@@ -1948,6 +1997,8 @@
   {@const isRunning = statusText === "RUNNING"}
   <!-- Instance Card -->
   {@const instanceModelId = getInstanceModelId(instance)}
+  {@const instanceQuantization =
+    getInstanceQuantization(instance)}
   {@const instanceInfo = getInstanceInfo(instance)}
   {@const instanceConnections =
     getInstanceConnections(instance)}
@@ -1963,7 +2014,10 @@
       instanceModelId !== "Unknown" &&
       instanceModelId !== "Unknown Model"
     ) {
-      setSelectedChatModel(instanceModelId);
+      setSelectedChatModel(
+        instanceModelId,
+        instanceQuantization,
+      );
     }
   }}
   onkeydown={(e) => {
@@ -1973,7 +2027,10 @@
       instanceModelId !== "Unknown" &&
       instanceModelId !== "Unknown Model"
     ) {
-      setSelectedChatModel(instanceModelId);
+      setSelectedChatModel(
+        instanceModelId,
+        instanceQuantization,
+      );
     }
   }
   }}
@@ -2064,7 +2121,9 @@
   <div
     class="text-exo-yellow text-xs font-mono tracking-wide truncate"
   >
-    {getInstanceModelId(instance)}
+    {getInstanceModelId(instance)}{instanceQuantization
+      ? ` (${instanceQuantization}-bit)`
+      : ""}
   </div>
   <div class="text-white/60 text-xs font-mono">
     Strategy: <span class="text-white/80"
@@ -2371,7 +2430,9 @@
   >
     {#if selectedModelId}
       {@const foundModel = models.find(
-        (m) => m.id === selectedModelId,
+        (m) =>
+          m.id === selectedModelId &&
+          (m.quantization ?? null) === selectedQuantization,
       )}
       {#if foundModel}
         {@const sizeGB = getModelSizeGB(foundModel)}
@@ -2424,7 +2485,7 @@
         </svg>
       {/if}
       <span class="truncate"
-        >{foundModel.name || foundModel.id}</span
+        >{getModelDisplayName(foundModel)}</span
       >
     </span>
     <span class="text-white/50 text-xs flex-shrink-0"
@@ -2503,6 +2564,7 @@
   onclick={() => {
     if (modelCanFit) {
       selectPreviewModel(model.id);
+      selectedQuantization = model.quantization ?? null;
       saveLaunchDefaults();
       isModelDropdownOpen = false;
       modelDropdownSearch = "";
@@ -2510,7 +2572,8 @@
   }}
   disabled={!modelCanFit}
   class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
-  model.id
+  model.id &&
+  selectedQuantization === (model.quantization ?? null)
     ? 'bg-transparent text-exo-yellow cursor-pointer'
     : modelCanFit
       ? 'text-white/80 hover:text-exo-yellow cursor-pointer'
@@ -2555,7 +2618,9 @@
       />
     </svg>
   {/if}
-  <span class="truncate">{model.name || model.id}</span>
+  <span class="truncate"
+    >{getModelDisplayName(model)}</span
+  >
 </span>
 <span
   class="flex-shrink-0 text-xs {modelCanFit
@@ -2765,14 +2830,16 @@
   </div>
 {:else}
   {@const selectedModel = models.find(
-    (m) => m.id === selectedModelId,
+    (m) =>
+      m.id === selectedModelId &&
+      (m.quantization ?? null) === selectedQuantization,
   )}
   {@const allPreviews = filteredPreviews()}
   {#if selectedModel && allPreviews.length > 0}
     {@const downloadStatus = getModelDownloadStatus(
       selectedModel.id,
     )}
-    {@const tags = modelTags()[selectedModel.id] || []}
+    {@const tags = modelTags()[getModelKey(selectedModel)] || []}
     <div class="space-y-3">
       {#each allPreviews as apiPreview, i}
         <div
@@ -2790,13 +2857,14 @@
   >
     <ModelCard
       model={selectedModel}
-      isLaunching={launchingModelId === selectedModel.id}
+      isLaunching={launchingModelId ===
+        getModelKey(selectedModel)}
       {downloadStatus}
       nodes={data?.nodes ?? {}}
       sharding={apiPreview.sharding}
       runtime={apiPreview.instance_meta}
       onLaunch={() =>
-        launchInstance(selectedModel.id, apiPreview)}
+        launchInstance(selectedModel, apiPreview)}
       {tags}
       {apiPreview}
       modelIdOverride={apiPreview.model_id}
@@ -2945,6 +3013,8 @@
   {@const isRunning = statusText === "RUNNING"}
   <!-- Instance Card -->
   {@const instanceModelId = getInstanceModelId(instance)}
+  {@const instanceQuantization =
+    getInstanceQuantization(instance)}
   {@const instanceInfo = getInstanceInfo(instance)}
   {@const instanceConnections =
     getInstanceConnections(instance)}
@@ -2960,7 +3030,10 @@
       instanceModelId !== "Unknown" &&
       instanceModelId !== "Unknown Model"
     ) {
-      setSelectedChatModel(instanceModelId);
+      setSelectedChatModel(
+        instanceModelId,
+        instanceQuantization,
+      );
     }
   }}
   onkeydown={(e) => {
@@ -2970,7 +3043,10 @@
       instanceModelId !== "Unknown" &&
       instanceModelId !== "Unknown Model"
     ) {
-      setSelectedChatModel(instanceModelId);
+      setSelectedChatModel(
+        instanceModelId,
+        instanceQuantization,
+      );
     }
   }
   }}
@@ -3061,7 +3137,9 @@
   <div
     class="text-exo-yellow text-xs font-mono tracking-wide truncate"
   >
-    {getInstanceModelId(instance)}
+    {getInstanceModelId(instance)}{instanceQuantization
+      ? ` (${instanceQuantization}-bit)`
+      : ""}
   </div>
   <div class="text-white/60 text-xs font-mono">
     Strategy: <span class="text-white/80"
```
```diff
@@ -141,14 +141,19 @@ def chunk_to_response(
     )


-async def resolve_model_card(model_id: ModelId) -> ModelCard:
+async def resolve_model_card(
+    model_id: ModelId, quantization: int | None = None
+) -> ModelCard:
     if model_id in MODEL_CARDS:
         model_card = MODEL_CARDS[model_id]
         return model_card

     for card in MODEL_CARDS.values():
         if card.model_id == ModelId(model_id):
-            return card
+            if quantization is None and card.quantization is None:
+                return card
+            if card.quantization == quantization:
+                return card

     return await ModelCard.from_hf(model_id)
```
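Resolution order matters here: an exact card-key hit wins; otherwise the scan by HuggingFace id returns the variant whose `quantization` matches the request, with `None` matching only the base card; only then does it fall back to `ModelCard.from_hf`. The matching rule, demonstrated on plain dicts (illustrative):

```python
# Same matching rule as resolve_model_card, on plain dicts.
cards = [
    {"model_id": "black-forest-labs/FLUX.1-schnell", "quantization": None},
    {"model_id": "black-forest-labs/FLUX.1-schnell", "quantization": 4},
]

def resolve(model_id: str, quantization: int | None):
    for card in cards:
        if card["model_id"] == model_id:
            if quantization is None and card["quantization"] is None:
                return card
            if card["quantization"] == quantization:
                return card
    return None  # the real code falls back to ModelCard.from_hf here

assert resolve("black-forest-labs/FLUX.1-schnell", 4)["quantization"] == 4
assert resolve("black-forest-labs/FLUX.1-schnell", None)["quantization"] is None
```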
```diff
@@ -354,7 +359,7 @@ class API:
         model_id: ModelId,
         node_ids: Annotated[list[NodeId] | None, Query()] = None,
     ) -> PlacementPreviewResponse:
-        seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
+        seen: set[tuple[ModelId, int | None, Sharding, InstanceMeta, int]] = set()
         previews: list[PlacementPreview] = []
         required_nodes = set(node_ids) if node_ids else None

@@ -396,17 +401,32 @@ class API:
                         required_nodes=required_nodes,
                     )
                 except ValueError as exc:
-                    if (model_card.model_id, sharding, instance_meta, 0) not in seen:
+                    if (
+                        model_card.model_id,
+                        model_card.quantization,
+                        sharding,
+                        instance_meta,
+                        0,
+                    ) not in seen:
                         previews.append(
                             PlacementPreview(
                                 model_id=model_card.model_id,
+                                quantization=model_card.quantization,
                                 sharding=sharding,
                                 instance_meta=instance_meta,
                                 instance=None,
                                 error=str(exc),
                             )
                         )
-                        seen.add((model_card.model_id, sharding, instance_meta, 0))
+                        seen.add(
+                            (
+                                model_card.model_id,
+                                model_card.quantization,
+                                sharding,
+                                instance_meta,
+                                0,
+                            )
+                        )
                     continue

                 current_ids = set(self.state.instances.keys())
@@ -417,17 +437,32 @@ class API:
                 ]

                 if len(new_instances) != 1:
-                    if (model_card.model_id, sharding, instance_meta, 0) not in seen:
+                    if (
+                        model_card.model_id,
+                        model_card.quantization,
+                        sharding,
+                        instance_meta,
+                        0,
+                    ) not in seen:
                         previews.append(
                             PlacementPreview(
                                 model_id=model_card.model_id,
+                                quantization=model_card.quantization,
                                 sharding=sharding,
                                 instance_meta=instance_meta,
                                 instance=None,
                                 error="Expected exactly one new instance from placement",
                             )
                         )
-                        seen.add((model_card.model_id, sharding, instance_meta, 0))
+                        seen.add(
+                            (
+                                model_card.model_id,
+                                model_card.quantization,
+                                sharding,
+                                instance_meta,
+                                0,
+                            )
+                        )
                     continue

                 instance = new_instances[0]
@@ -447,6 +482,7 @@ class API:

                 if (
                     model_card.model_id,
+                    model_card.quantization,
                     sharding,
                     instance_meta,
                     len(placement_node_ids),
@@ -454,6 +490,7 @@ class API:
                     previews.append(
                         PlacementPreview(
                             model_id=model_card.model_id,
+                            quantization=model_card.quantization,
                             sharding=sharding,
                             instance_meta=instance_meta,
                             instance=instance,
@@ -464,6 +501,7 @@ class API:
                 seen.add(
                     (
                         model_card.model_id,
+                        model_card.quantization,
                         sharding,
                         instance_meta,
                         len(placement_node_ids),
```
```diff
@@ -669,6 +707,18 @@ class API:
                 "TODO: we should send a notification to the user to download the model"
             )

+    def _has_matching_instance(
+        self, model_id: ModelId, quantization: int | None
+    ) -> bool:
+        """Check if there's a running instance matching the model_id and quantization."""
+        for instance in self.state.instances.values():
+            if (
+                instance.shard_assignments.model_id == model_id
+                and instance.shard_assignments.quantization == quantization
+            ):
+                return True
+        return False
+
     async def chat_completions(
         self, payload: ChatCompletionTaskParams
     ) -> ChatCompletionResponse | StreamingResponse:
```
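Note the helper requires an exact quantization match: a request for the base model (quantization `None`) is not served by a 4-bit instance, and vice versa. Illustrated on dummy data:

```python
# Exact-match semantics of _has_matching_instance (illustrative data).
instances = [{"model_id": "Qwen/Qwen-Image", "quantization": 4}]

def has_matching_instance(model_id: str, quantization: int | None) -> bool:
    return any(
        i["model_id"] == model_id and i["quantization"] == quantization
        for i in instances
    )

assert has_matching_instance("Qwen/Qwen-Image", 4)
assert not has_matching_instance("Qwen/Qwen-Image", None)  # base model not running
```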
```diff
@@ -720,22 +770,23 @@ class API:
         response = await self._collect_chat_completion_with_stats(command.command_id)
         return response

-    async def _validate_image_model(self, model: str) -> ModelId:
-        """Validate model exists and return resolved model ID.
+    async def _validate_image_model(
+        self, model: str, quantization: int | None = None
+    ) -> tuple[ModelId, int | None]:
+        """Validate model exists and return resolved model ID and quantization.

         Raises HTTPException 404 if no instance is found for the model.
+        Returns tuple of (model_id, quantization).
         """
-        model_card = await resolve_model_card(ModelId(model))
+        model_card = await resolve_model_card(ModelId(model), quantization)
         resolved_model = model_card.model_id
-        if not any(
-            instance.shard_assignments.model_id == resolved_model
-            for instance in self.state.instances.values()
-        ):
+        resolved_quant = model_card.quantization
+        if not self._has_matching_instance(ModelId(resolved_model), resolved_quant):
             await self._trigger_notify_user_to_download_model(resolved_model)
             raise HTTPException(
                 status_code=404, detail=f"No instance found for model {resolved_model}"
             )
-        return resolved_model
+        return resolved_model, resolved_quant

     async def get_image(self, image_id: str) -> FileResponse:
         stored = self._image_store.get(Id(image_id))
@@ -771,7 +822,9 @@ class API:
         When stream=True and partial_images > 0, returns a StreamingResponse
         with SSE-formatted events for partial and final images.
         """
-        payload.model = await self._validate_image_model(payload.model)
+        payload.model, payload.quantization = await self._validate_image_model(
+            payload.model, payload.quantization
+        )

         command = ImageGeneration(
             request_params=payload,
@@ -1016,7 +1069,9 @@ class API:
     async def bench_image_generations(
         self, request: Request, payload: BenchImageGenerationTaskParams
     ) -> BenchImageGenerationResponse:
-        payload.model = await self._validate_image_model(payload.model)
+        payload.model, payload.quantization = await self._validate_image_model(
+            payload.model, payload.quantization
+        )

         payload.stream = False
         payload.partial_images = 0
```
```diff
@@ -1038,6 +1093,7 @@ class API:
         image: UploadFile,
         prompt: str,
         model: str,
+        quantization: int | None,
         n: int,
         size: str,
         response_format: Literal["url", "b64_json"],
@@ -1050,7 +1106,9 @@ class API:
         advanced_params: AdvancedImageParams | None,
     ) -> ImageEdits:
         """Prepare and send an image edits command with chunked image upload."""
-        resolved_model = await self._validate_image_model(model)
+        resolved_model, resolved_quant = await self._validate_image_model(
+            model, quantization
+        )

         image_content = await image.read()
         image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1069,6 +1127,7 @@ class API:
             total_input_chunks=total_chunks,
             prompt=prompt,
             model=resolved_model,
+            quantization=resolved_quant,
             n=n,
             size=size,
             response_format=response_format,
@@ -1107,6 +1166,7 @@ class API:
         image: UploadFile = File(...),  # noqa: B008
         prompt: str = Form(...),
         model: str = Form(...),
+        quantization: str | None = Form(None),
         n: int = Form(1),
         size: str = Form("1024x1024"),
         response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1121,6 +1181,9 @@ class API:
         # Parse string form values to proper types
         stream_bool = stream.lower() in ("true", "1", "yes")
         partial_images_int = int(partial_images) if partial_images.isdigit() else 0
+        quantization_int = (
+            int(quantization) if quantization and quantization.isdigit() else None
+        )

         parsed_advanced_params: AdvancedImageParams | None = None
         if advanced_params:
@@ -1133,6 +1196,7 @@ class API:
             image=image,
             prompt=prompt,
             model=model,
+            quantization=quantization_int,
             n=n,
             size=size,
             response_format=response_format,
```
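The multipart form sends `quantization` as a string, parsed with the same guard already used for `partial_images`; `isdigit()` quietly maps anything non-numeric (or empty) to `None`:

```python
def parse_quantization(raw: str | None) -> int | None:
    # Same expression as the hunk above.
    return int(raw) if raw and raw.isdigit() else None

assert parse_quantization("4") == 4
assert parse_quantization("") is None
assert parse_quantization(None) is None
assert parse_quantization("4.5") is None  # isdigit() rejects non-integers
```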
```diff
@@ -1169,6 +1233,7 @@ class API:
         image: UploadFile = File(...),  # noqa: B008
         prompt: str = Form(...),
         model: str = Form(...),
+        quantization: str | None = Form(None),
         n: int = Form(1),
         size: str = Form("1024x1024"),
         response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1178,6 +1243,10 @@ class API:
         advanced_params: str | None = Form(None),
     ) -> BenchImageGenerationResponse:
         """Handle benchmark image editing requests with generation stats."""
+        quantization_int = (
+            int(quantization) if quantization and quantization.isdigit() else None
+        )
+
         parsed_advanced_params: AdvancedImageParams | None = None
         if advanced_params:
             with contextlib.suppress(Exception):
@@ -1189,6 +1258,7 @@ class API:
             image=image,
             prompt=prompt,
             model=model,
+            quantization=quantization_int,
             n=n,
             size=size,
             response_format=response_format,
@@ -1223,6 +1293,7 @@ class API:
         data=[
             ModelListModel(
                 id=card.model_id,
+                quantization=card.quantization,
                 hugging_face_id=card.model_id,
                 name=card.model_id.short(),
                 description="",
```
```diff
@@ -137,6 +137,7 @@ def get_shard_assignments_for_pipeline_parallel(

     shard_assignments = ShardAssignments(
         model_id=model_card.model_id,
+        quantization=model_card.quantization,
         runner_to_shard=runner_to_shard,
         node_to_runner=node_to_runner,
     )
@@ -170,6 +171,7 @@ def get_shard_assignments_for_tensor_parallel(

     shard_assignments = ShardAssignments(
         model_id=model_card.model_id,
+        quantization=model_card.quantization,
        runner_to_shard=runner_to_shard,
         node_to_runner=node_to_runner,
     )
```
```diff
@@ -40,6 +40,7 @@ class ModelCard(CamelCaseModel):
     supports_tensor: bool
     tasks: list[ModelTask]
     components: list[ComponentInfo] | None = None
+    quantization: int | None = None

     @field_validator("tasks", mode="before")
     @classmethod
@@ -413,7 +414,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }

-_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
+_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
     "flux1-schnell": ModelCard(
         model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
         storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -428,7 +429,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             storage_size=Memory.from_kb(0),
             n_layers=12,
             can_shard=False,
-            safetensors_index_filename=None,  # Single file
+            safetensors_index_filename=None,
         ),
         ComponentInfo(
             component_name="text_encoder_2",
@@ -442,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             component_name="transformer",
             component_path="transformer/",
             storage_size=Memory.from_bytes(23782357120),
-            n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
+            n_layers=57,
             can_shard=True,
             safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
         ),
@@ -470,7 +471,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             storage_size=Memory.from_kb(0),
             n_layers=12,
             can_shard=False,
-            safetensors_index_filename=None,  # Single file
+            safetensors_index_filename=None,
         ),
         ComponentInfo(
             component_name="text_encoder_2",
@@ -484,7 +485,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             component_name="transformer",
             component_path="transformer/",
             storage_size=Memory.from_bytes(23802816640),
-            n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
+            n_layers=57,
             can_shard=True,
             safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
         ),
@@ -543,7 +544,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     "qwen-image": ModelCard(
         model_id=ModelId("Qwen/Qwen-Image"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
+        n_layers=60,
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.TextToImage],
@@ -551,10 +552,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
         ComponentInfo(
             component_name="text_encoder",
             component_path="text_encoder/",
-            storage_size=Memory.from_kb(16584333312),
+            storage_size=Memory.from_bytes(16584333312),
             n_layers=12,
             can_shard=False,
-            safetensors_index_filename=None,  # Single file
+            safetensors_index_filename=None,
         ),
         ComponentInfo(
             component_name="transformer",
@@ -577,7 +578,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     "qwen-image-edit-2509": ModelCard(
         model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
+        n_layers=60,
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.ImageToImage],
@@ -585,10 +586,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
         ComponentInfo(
             component_name="text_encoder",
             component_path="text_encoder/",
-            storage_size=Memory.from_kb(16584333312),
+            storage_size=Memory.from_bytes(16584333312),
             n_layers=12,
             can_shard=False,
-            safetensors_index_filename=None,  # Single file
+            safetensors_index_filename=None,
         ),
         ComponentInfo(
             component_name="transformer",
```
```diff
@@ -610,6 +611,91 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }


+def _create_image_model_quant_variants(
+    base_name: str,
+    base_card: ModelCard,
+) -> dict[str, ModelCard]:
+    """Create quantized variants of an image model card.
+
+    Only the transformer component is quantized; text encoders stay at bf16.
+    Sizes are calculated exactly from the base card's component sizes.
+    """
+    if base_card.components is None:
+        raise ValueError(f"Image model {base_name} must have components defined")
+
+    quantizations = [8, 6, 5, 4, 3]
+
+    num_transformer_bytes = next(
+        c.storage_size.in_bytes
+        for c in base_card.components
+        if c.component_name == "transformer"
+    )
+
+    transformer_bytes = Memory.from_bytes(num_transformer_bytes)
+
+    remaining_bytes = Memory.from_bytes(
+        sum(
+            c.storage_size.in_bytes
+            for c in base_card.components
+            if c.component_name != "transformer"
+        )
+    )
+
+    def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
+        assert base_card.components is not None
+        return [
+            ComponentInfo(
+                component_name=c.component_name,
+                component_path=c.component_path,
+                storage_size=new_size
+                if c.component_name == "transformer"
+                else c.storage_size,
+                n_layers=c.n_layers,
+                can_shard=c.can_shard,
+                safetensors_index_filename=c.safetensors_index_filename,
+            )
+            for c in base_card.components
+        ]
+
+    variants = {
+        base_name: ModelCard(
+            model_id=base_card.model_id,
+            storage_size=transformer_bytes + remaining_bytes,
+            n_layers=base_card.n_layers,
+            hidden_size=base_card.hidden_size,
+            supports_tensor=base_card.supports_tensor,
+            tasks=base_card.tasks,
+            components=with_transformer_size(transformer_bytes),
+            quantization=None,
+        )
+    }
+
+    for quant in quantizations:
+        quant_transformer_bytes = Memory.from_bytes(
+            (num_transformer_bytes * quant) // 16
+        )
+        total_bytes = remaining_bytes + quant_transformer_bytes
+
+        variants[f"{base_name}-{quant}bit"] = ModelCard(
+            model_id=base_card.model_id,
+            storage_size=total_bytes,
+            n_layers=base_card.n_layers,
+            hidden_size=base_card.hidden_size,
+            supports_tensor=base_card.supports_tensor,
+            tasks=base_card.tasks,
+            components=with_transformer_size(quant_transformer_bytes),
+            quantization=quant,
+        )
+
+    return variants
+
+
+_image_model_cards: dict[str, ModelCard] = {}
+for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
+    _image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
+_IMAGE_MODEL_CARDS = _image_model_cards
+
 if EXO_ENABLE_IMAGE_MODELS:
     MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
```
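The size math: only the transformer shrinks, by a factor of `bits / 16` (the base weights being 16-bit), while the text encoders stay full precision. Worked through for flux1-schnell using the byte counts from the card above:

```python
# Worked example of the variant size math for flux1-schnell.
transformer = 23_782_357_120  # bf16 transformer weights (bytes)
encoders = 9_524_621_312      # text encoders, never quantized (bytes)

for bits in [8, 6, 5, 4, 3]:
    quant_transformer = (transformer * bits) // 16
    total = encoders + quant_transformer
    print(f"{bits}-bit: transformer {quant_transformer / 1e9:.1f} GB, total {total / 1e9:.1f} GB")

# e.g. 4-bit: transformer ~5.9 GB, total ~15.5 GB (vs ~33.3 GB for the base model)
```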
```diff
@@ -31,6 +31,7 @@ class ErrorResponse(BaseModel):

 class ModelListModel(BaseModel):
     id: str
+    quantization: int | None = None
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "exo"
@@ -216,6 +217,7 @@ class CreateInstanceParams(BaseModel):

 class PlacementPreview(BaseModel):
     model_id: ModelId
+    quantization: int | None = None
     sharding: Sharding
     instance_meta: InstanceMeta
     instance: Instance | None = None
@@ -267,6 +269,7 @@ class ImageGenerationTaskParams(BaseModel):
     style: str | None = "vivid"
     user: str | None = None
     advanced_params: AdvancedImageParams | None = None
+    quantization: int | None = None
     # Internal flag for benchmark mode - set by API, preserved through serialization
     bench: bool = False

@@ -292,6 +295,7 @@ class ImageEditsTaskParams(BaseModel):
     stream: bool | None = False
     user: str | None = None
     advanced_params: AdvancedImageParams | None = None
+    quantization: int | None = None
     # Internal flag for benchmark mode - set by API, preserved through serialization
     bench: bool = False

@@ -303,6 +307,7 @@ class ImageEditsInternalParams(BaseModel):
     total_input_chunks: int = 0
     prompt: str
     model: str
+    quantization: int | None = None
     n: int | None = 1
     quality: Literal["high", "medium", "low"] | None = "medium"
     output_format: Literal["png", "jpeg", "webp"] = "png"
```
```diff
@@ -82,6 +82,7 @@ RunnerStatus = (

 class ShardAssignments(CamelCaseModel):
     model_id: ModelId
+    quantization: int | None = None
     runner_to_shard: Mapping[RunnerId, ShardMetadata]
     node_to_runner: Mapping[NodeId, RunnerId]
```
```diff
@@ -71,8 +71,10 @@ class DistributedImageModel:
     def from_bound_instance(
         cls, bound_instance: BoundInstance
     ) -> "DistributedImageModel":
-        model_id = bound_instance.bound_shard.model_card.model_id
+        model_card = bound_instance.bound_shard.model_card
+        model_id = model_card.model_id
         model_path = build_model_path(model_id)
+        quantize = model_card.quantization

         shard_metadata = bound_instance.bound_shard
         if not isinstance(shard_metadata, PipelineShardMetadata):
@@ -93,6 +95,7 @@ class DistributedImageModel:
             local_path=model_path,
             shard_metadata=shard_metadata,
             group=group,
+            quantize=quantize,
         )

     def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
```