mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-24 13:59:48 -05:00
Compare commits
7 Commits
ciaran/ima
...
v1.0.65
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
281aaeb013 | ||
|
|
10fdc439a5 | ||
|
|
78a8c06d57 | ||
|
|
4c0c6dcae9 | ||
|
|
d885600a4c | ||
|
|
55b67e2be2 | ||
|
|
30cfad9b68 |
@@ -18,6 +18,9 @@ enum NetworkSetupHelper {
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Wait for macOS to finish network setup after boot
|
||||
sleep 30
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
@@ -80,7 +83,7 @@ enum NetworkSetupHelper {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
|
||||
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Install")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
|
||||
@@ -112,23 +112,18 @@
|
||||
|
||||
// Extract available models from running instances
|
||||
const availableModels = $derived(() => {
|
||||
const models: Array<{
|
||||
id: string;
|
||||
quantization: number | null;
|
||||
label: string;
|
||||
isImageModel: boolean;
|
||||
}> = [];
|
||||
const models: Array<{ id: string; label: string; isImageModel: boolean }> =
|
||||
[];
|
||||
for (const [, instance] of Object.entries(instanceData)) {
|
||||
const { modelId, quantization } = getInstanceModelInfo(instance);
|
||||
const modelId = getInstanceModelId(instance);
|
||||
if (
|
||||
modelId &&
|
||||
modelId !== "Unknown" &&
|
||||
!models.some((m) => m.id === modelId && m.quantization === quantization)
|
||||
!models.some((m) => m.id === modelId)
|
||||
) {
|
||||
models.push({
|
||||
id: modelId,
|
||||
quantization,
|
||||
label: `${modelId.split("/").pop() || modelId}${quantization ? ` (${quantization}-bit)` : ""}`,
|
||||
label: modelId.split("/").pop() || modelId,
|
||||
isImageModel: modelSupportsImageGeneration(modelId),
|
||||
});
|
||||
}
|
||||
@@ -150,20 +145,20 @@
|
||||
|
||||
// If no model selected, select the first available
|
||||
if (!currentModel) {
|
||||
setSelectedChatModel(models[0].id, models[0].quantization);
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If current model is stale (no longer has a running instance), reset to first available
|
||||
else if (!models.some((m) => m.id === currentModel)) {
|
||||
setSelectedChatModel(models[0].id, models[0].quantization);
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If a new model was just added, select it
|
||||
else if (newModels.length > 0 && previousModelIds.size > 0) {
|
||||
setSelectedChatModel(newModels[0].id, newModels[0].quantization);
|
||||
setSelectedChatModel(newModels[0].id);
|
||||
}
|
||||
} else {
|
||||
// No instances running - clear the selected model
|
||||
if (currentModel) {
|
||||
setSelectedChatModel("", null);
|
||||
setSelectedChatModel("");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,23 +166,16 @@
|
||||
previousModelIds = currentModelIds;
|
||||
});
|
||||
|
||||
function getInstanceModelInfo(instanceWrapped: unknown): {
|
||||
modelId: string;
|
||||
quantization: number | null;
|
||||
} {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== "object")
|
||||
return { modelId: "", quantization: null };
|
||||
function getInstanceModelId(instanceWrapped: unknown): string {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
|
||||
const keys = Object.keys(instanceWrapped as Record<string, unknown>);
|
||||
if (keys.length === 1) {
|
||||
const instance = (instanceWrapped as Record<string, unknown>)[
|
||||
keys[0]
|
||||
] as { shardAssignments?: { modelId?: string; quantization?: number } };
|
||||
return {
|
||||
modelId: instance?.shardAssignments?.modelId || "",
|
||||
quantization: instance?.shardAssignments?.quantization ?? null,
|
||||
};
|
||||
] as { shardAssignments?: { modelId?: string } };
|
||||
return instance?.shardAssignments?.modelId || "";
|
||||
}
|
||||
return { modelId: "", quantization: null };
|
||||
return "";
|
||||
}
|
||||
|
||||
async function handleFiles(files: File[]) {
|
||||
@@ -481,7 +469,7 @@
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
setSelectedChatModel(model.id, model.quantization);
|
||||
setSelectedChatModel(model.id);
|
||||
isModelDropdownOpen = false;
|
||||
}}
|
||||
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===
|
||||
|
||||
@@ -142,15 +142,11 @@
|
||||
return null;
|
||||
}
|
||||
|
||||
function formatModelName(
|
||||
modelId: string | null | undefined,
|
||||
quantization: number | null | undefined,
|
||||
): string {
|
||||
function formatModelName(modelId: string | null | undefined): string {
|
||||
if (!modelId) return "Unknown Model";
|
||||
const parts = modelId.split("/");
|
||||
const tail = parts[parts.length - 1] || modelId;
|
||||
const baseName = tail || modelId;
|
||||
return quantization ? `${baseName} (${quantization}-bit)` : baseName;
|
||||
return tail || modelId;
|
||||
}
|
||||
|
||||
function formatStrategy(
|
||||
@@ -248,7 +244,7 @@
|
||||
conversation.instanceType ?? instanceDetails.instanceType;
|
||||
|
||||
return {
|
||||
modelLabel: formatModelName(displayModel, conversation.quantization),
|
||||
modelLabel: formatModelName(displayModel),
|
||||
strategyLabel: formatStrategy(sharding, instanceType),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -162,7 +162,6 @@ export interface ModelDownloadStatus {
|
||||
// Placement preview from the API
|
||||
export interface PlacementPreview {
|
||||
model_id: string;
|
||||
quantization: number | null; // quantization bits or null for base model
|
||||
sharding: "Pipeline" | "Tensor";
|
||||
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
instance: unknown | null;
|
||||
@@ -228,7 +227,6 @@ export interface Conversation {
|
||||
createdAt: number;
|
||||
updatedAt: number;
|
||||
modelId: string | null;
|
||||
quantization: number | null;
|
||||
sharding: string | null;
|
||||
instanceType: string | null;
|
||||
}
|
||||
@@ -493,7 +491,6 @@ class AppStore {
|
||||
createdAt: conversation.createdAt ?? Date.now(),
|
||||
updatedAt: conversation.updatedAt ?? Date.now(),
|
||||
modelId: conversation.modelId ?? null,
|
||||
quantization: conversation.quantization ?? null,
|
||||
sharding: conversation.sharding ?? null,
|
||||
instanceType: conversation.instanceType ?? null,
|
||||
}));
|
||||
@@ -672,7 +669,6 @@ class AppStore {
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
modelId: derivedModelId,
|
||||
quantization: this.selectedChatModelQuantization,
|
||||
sharding: derivedSharding,
|
||||
instanceType: derivedInstanceType,
|
||||
};
|
||||
@@ -1480,7 +1476,6 @@ class AppStore {
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
quantization: this.selectedChatModelQuantization,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
}),
|
||||
@@ -1566,17 +1561,11 @@ class AppStore {
|
||||
*/
|
||||
selectedChatModel = $state("");
|
||||
|
||||
/**
|
||||
* Selected model's quantization (null for base/unquantized models)
|
||||
*/
|
||||
selectedChatModelQuantization = $state<number | null>(null);
|
||||
|
||||
/**
|
||||
* Set the model to use for chat
|
||||
*/
|
||||
setSelectedModel(modelId: string, quantization: number | null = null) {
|
||||
setSelectedModel(modelId: string) {
|
||||
this.selectedChatModel = modelId;
|
||||
this.selectedChatModelQuantization = quantization;
|
||||
// Clear stats when model changes
|
||||
this.ttftMs = null;
|
||||
this.tps = null;
|
||||
@@ -1887,7 +1876,6 @@ class AppStore {
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
quantization: this.selectedChatModelQuantization,
|
||||
messages: apiMessages,
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
@@ -2070,7 +2058,6 @@ class AppStore {
|
||||
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
quantization: this.selectedChatModelQuantization,
|
||||
prompt,
|
||||
n: params.numImages,
|
||||
quality: params.quality,
|
||||
@@ -2297,12 +2284,6 @@ class AppStore {
|
||||
// Build FormData request
|
||||
const formData = new FormData();
|
||||
formData.append("model", model);
|
||||
if (this.selectedChatModelQuantization !== null) {
|
||||
formData.append(
|
||||
"quantization",
|
||||
this.selectedChatModelQuantization.toString(),
|
||||
);
|
||||
}
|
||||
formData.append("prompt", prompt);
|
||||
formData.append("image", imageBlob, "image.png");
|
||||
|
||||
@@ -2532,8 +2513,6 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
|
||||
export const lastUpdate = () => appStore.lastUpdate;
|
||||
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
|
||||
export const selectedChatModel = () => appStore.selectedChatModel;
|
||||
export const selectedChatModelQuantization = () =>
|
||||
appStore.selectedChatModelQuantization;
|
||||
export const debugMode = () => appStore.getDebugMode();
|
||||
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
|
||||
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
|
||||
@@ -2562,10 +2541,8 @@ export const setEditingImage = (imageDataUrl: string, sourceMessage: Message) =>
|
||||
appStore.setEditingImage(imageDataUrl, sourceMessage);
|
||||
export const clearEditingImage = () => appStore.clearEditingImage();
|
||||
export const clearChat = () => appStore.clearChat();
|
||||
export const setSelectedChatModel = (
|
||||
modelId: string,
|
||||
quantization: number | null = null,
|
||||
) => appStore.setSelectedModel(modelId, quantization);
|
||||
export const setSelectedChatModel = (modelId: string) =>
|
||||
appStore.setSelectedModel(modelId);
|
||||
export const selectPreviewModel = (modelId: string | null) =>
|
||||
appStore.selectPreviewModel(modelId);
|
||||
export const togglePreviewNodeFilter = (nodeId: string) =>
|
||||
|
||||
@@ -96,7 +96,6 @@
|
||||
let models = $state<
|
||||
Array<{
|
||||
id: string;
|
||||
quantization?: number | null;
|
||||
name?: string;
|
||||
storage_size_megabytes?: number;
|
||||
tasks?: string[];
|
||||
@@ -104,38 +103,12 @@
|
||||
}>
|
||||
>([]);
|
||||
|
||||
// Helper to get unique model key (combines id + quantization)
|
||||
function getModelKey(model: {
|
||||
id: string;
|
||||
quantization?: number | null;
|
||||
}): string {
|
||||
return model.quantization != null
|
||||
? `${model.id}-q${model.quantization}`
|
||||
: model.id;
|
||||
}
|
||||
|
||||
// Helper to get display name with quantization suffix
|
||||
function getModelDisplayName(model: {
|
||||
id: string;
|
||||
name?: string;
|
||||
quantization?: number | null;
|
||||
}): string {
|
||||
const baseName = model.name || model.id;
|
||||
if (model.quantization != null) {
|
||||
return `${baseName} (${model.quantization}-bit)`;
|
||||
}
|
||||
return baseName;
|
||||
}
|
||||
|
||||
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
|
||||
const modelTasks = $derived(() => {
|
||||
const tasks: Record<string, string[]> = {};
|
||||
for (const model of models) {
|
||||
if (model.tasks && model.tasks.length > 0) {
|
||||
// Map by unique key (model_id + quantization)
|
||||
const key = getModelKey(model);
|
||||
tasks[key] = model.tasks;
|
||||
// Also map by short ID (for backwards compatibility)
|
||||
// Map by short ID
|
||||
tasks[model.id] = model.tasks;
|
||||
// Also map by hugging_face_id from the API response
|
||||
if (model.hugging_face_id) {
|
||||
@@ -173,7 +146,6 @@
|
||||
const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
|
||||
interface LaunchDefaults {
|
||||
modelId: string | null;
|
||||
quantization: number | null;
|
||||
sharding: "Pipeline" | "Tensor";
|
||||
instanceType: InstanceMeta;
|
||||
minNodes: number;
|
||||
@@ -182,7 +154,6 @@
|
||||
function saveLaunchDefaults(): void {
|
||||
const defaults: LaunchDefaults = {
|
||||
modelId: selectedPreviewModelId(),
|
||||
quantization: selectedQuantization,
|
||||
sharding: selectedSharding,
|
||||
instanceType: selectedInstanceType,
|
||||
minNodes: selectedMinNodes,
|
||||
@@ -206,7 +177,7 @@
|
||||
}
|
||||
|
||||
function applyLaunchDefaults(
|
||||
availableModels: Array<{ id: string; quantization?: number | null }>,
|
||||
availableModels: Array<{ id: string }>,
|
||||
maxNodes: number,
|
||||
): void {
|
||||
const defaults = loadLaunchDefaults();
|
||||
@@ -225,17 +196,12 @@
|
||||
selectedMinNodes = defaults.minNodes;
|
||||
}
|
||||
|
||||
// Only apply model if it exists in the available models (matching both id and quantization)
|
||||
if (defaults.modelId) {
|
||||
const matchingModel = availableModels.find(
|
||||
(m) =>
|
||||
m.id === defaults.modelId &&
|
||||
(m.quantization ?? null) === (defaults.quantization ?? null),
|
||||
);
|
||||
if (matchingModel) {
|
||||
selectPreviewModel(defaults.modelId);
|
||||
selectedQuantization = defaults.quantization ?? null;
|
||||
}
|
||||
// Only apply model if it exists in the available models
|
||||
if (
|
||||
defaults.modelId &&
|
||||
availableModels.some((m) => m.id === defaults.modelId)
|
||||
) {
|
||||
selectPreviewModel(defaults.modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +209,6 @@
|
||||
let selectedMinNodes = $state<number>(1);
|
||||
let minNodesInitialized = $state(false);
|
||||
let launchingModelId = $state<string | null>(null);
|
||||
let selectedQuantization = $state<number | null>(null);
|
||||
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
|
||||
|
||||
// Custom dropdown state
|
||||
@@ -504,40 +469,39 @@
|
||||
if (models.length === 0) return tags;
|
||||
|
||||
// Find the fastest model (highest TPS)
|
||||
let fastestKey = "";
|
||||
let fastestId = "";
|
||||
let highestTps = 0;
|
||||
|
||||
// Find the biggest model (most memory)
|
||||
let biggestKey = "";
|
||||
let biggestId = "";
|
||||
let highestMemory = 0;
|
||||
|
||||
for (const model of models) {
|
||||
const key = getModelKey(model);
|
||||
const perf = estimatePerformance(model.id);
|
||||
const mem = getModelSizeGB(model);
|
||||
|
||||
if (perf.tps > highestTps) {
|
||||
highestTps = perf.tps;
|
||||
fastestKey = key;
|
||||
fastestId = model.id;
|
||||
}
|
||||
|
||||
if (mem > highestMemory) {
|
||||
highestMemory = mem;
|
||||
biggestKey = key;
|
||||
biggestId = model.id;
|
||||
}
|
||||
}
|
||||
|
||||
if (fastestKey) {
|
||||
tags[fastestKey] = tags[fastestKey] || [];
|
||||
tags[fastestKey].push("FASTEST");
|
||||
if (fastestId) {
|
||||
tags[fastestId] = tags[fastestId] || [];
|
||||
tags[fastestId].push("FASTEST");
|
||||
}
|
||||
|
||||
if (biggestKey && biggestKey !== fastestKey) {
|
||||
tags[biggestKey] = tags[biggestKey] || [];
|
||||
tags[biggestKey].push("BIGGEST");
|
||||
} else if (biggestKey === fastestKey && biggestKey) {
|
||||
if (biggestId && biggestId !== fastestId) {
|
||||
tags[biggestId] = tags[biggestId] || [];
|
||||
tags[biggestId].push("BIGGEST");
|
||||
} else if (biggestId === fastestId && biggestId) {
|
||||
// Same model is both - unlikely but handle it
|
||||
tags[biggestKey].push("BIGGEST");
|
||||
tags[biggestId].push("BIGGEST");
|
||||
}
|
||||
|
||||
return tags;
|
||||
@@ -567,13 +531,12 @@
|
||||
}
|
||||
|
||||
async function launchInstance(
|
||||
model: { id: string; quantization?: number | null },
|
||||
modelId: string,
|
||||
specificPreview?: PlacementPreview | null,
|
||||
) {
|
||||
const modelKey = getModelKey(model);
|
||||
if (!model.id || launchingModelId) return;
|
||||
if (!modelId || launchingModelId) return;
|
||||
|
||||
launchingModelId = modelKey;
|
||||
launchingModelId = modelId;
|
||||
|
||||
try {
|
||||
// Use the specific preview if provided, otherwise fall back to filtered preview
|
||||
@@ -587,7 +550,7 @@
|
||||
} else {
|
||||
// Fallback: GET placement from API
|
||||
const placementResponse = await fetch(
|
||||
`/instance/placement?model_id=${encodeURIComponent(model.id)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
|
||||
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
|
||||
);
|
||||
|
||||
if (!placementResponse.ok) {
|
||||
@@ -611,7 +574,7 @@
|
||||
console.error("Failed to launch instance:", errorText);
|
||||
} else {
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(model.id, model.quantization ?? null);
|
||||
setSelectedChatModel(modelId);
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
@@ -1111,20 +1074,19 @@
|
||||
const [, lastInstance] =
|
||||
remainingInstances[remainingInstances.length - 1];
|
||||
const newModelId = getInstanceModelId(lastInstance);
|
||||
const newQuantization = getInstanceQuantization(lastInstance);
|
||||
if (
|
||||
newModelId &&
|
||||
newModelId !== "Unknown" &&
|
||||
newModelId !== "Unknown Model"
|
||||
) {
|
||||
setSelectedChatModel(newModelId, newQuantization);
|
||||
setSelectedChatModel(newModelId);
|
||||
} else {
|
||||
// Clear selection if no valid model found
|
||||
setSelectedChatModel("", null);
|
||||
setSelectedChatModel("");
|
||||
}
|
||||
} else {
|
||||
// No more instances, clear the selection
|
||||
setSelectedChatModel("", null);
|
||||
setSelectedChatModel("");
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -1150,16 +1112,6 @@
|
||||
return inst.shardAssignments?.modelId || "Unknown Model";
|
||||
}
|
||||
|
||||
// Get quantization from an instance
|
||||
function getInstanceQuantization(instanceWrapped: unknown): number | null {
|
||||
const [, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== "object") return null;
|
||||
const inst = instance as {
|
||||
shardAssignments?: { quantization?: number | null };
|
||||
};
|
||||
return inst.shardAssignments?.quantization ?? null;
|
||||
}
|
||||
|
||||
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
|
||||
function getInstanceInfo(instanceWrapped: unknown): {
|
||||
instanceType: string;
|
||||
@@ -1581,16 +1533,15 @@
|
||||
|
||||
// Get ALL filtered previews based on current settings (matching minimum nodes)
|
||||
// Note: previewsData already contains previews for the selected model (fetched via API)
|
||||
// Backend handles node_ids filtering, we filter by sharding/instance type, quantization, and min nodes
|
||||
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
|
||||
const filteredPreviews = $derived(() => {
|
||||
if (!selectedModelId || previewsData.length === 0) return [];
|
||||
|
||||
// Find previews matching sharding/instance type and quantization
|
||||
// Find previews matching sharding/instance type (model_id filter not needed since previewsData is already for selected model)
|
||||
const matchingPreviews = previewsData.filter(
|
||||
(p: PlacementPreview) =>
|
||||
p.sharding === selectedSharding &&
|
||||
matchesSelectedRuntime(p.instance_meta) &&
|
||||
p.quantization === selectedQuantization &&
|
||||
p.error === null &&
|
||||
p.memory_delta_by_node !== null,
|
||||
);
|
||||
@@ -1997,8 +1948,6 @@
|
||||
{@const isRunning = statusText === "RUNNING"}
|
||||
<!-- Instance Card -->
|
||||
{@const instanceModelId = getInstanceModelId(instance)}
|
||||
{@const instanceQuantization =
|
||||
getInstanceQuantization(instance)}
|
||||
{@const instanceInfo = getInstanceInfo(instance)}
|
||||
{@const instanceConnections =
|
||||
getInstanceConnections(instance)}
|
||||
@@ -2014,10 +1963,7 @@
|
||||
instanceModelId !== "Unknown" &&
|
||||
instanceModelId !== "Unknown Model"
|
||||
) {
|
||||
setSelectedChatModel(
|
||||
instanceModelId,
|
||||
instanceQuantization,
|
||||
);
|
||||
setSelectedChatModel(instanceModelId);
|
||||
}
|
||||
}}
|
||||
onkeydown={(e) => {
|
||||
@@ -2027,10 +1973,7 @@
|
||||
instanceModelId !== "Unknown" &&
|
||||
instanceModelId !== "Unknown Model"
|
||||
) {
|
||||
setSelectedChatModel(
|
||||
instanceModelId,
|
||||
instanceQuantization,
|
||||
);
|
||||
setSelectedChatModel(instanceModelId);
|
||||
}
|
||||
}
|
||||
}}
|
||||
@@ -2121,9 +2064,7 @@
|
||||
<div
|
||||
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
|
||||
>
|
||||
{getInstanceModelId(instance)}{instanceQuantization
|
||||
? ` (${instanceQuantization}-bit)`
|
||||
: ""}
|
||||
{getInstanceModelId(instance)}
|
||||
</div>
|
||||
<div class="text-white/60 text-xs font-mono">
|
||||
Strategy: <span class="text-white/80"
|
||||
@@ -2430,9 +2371,7 @@
|
||||
>
|
||||
{#if selectedModelId}
|
||||
{@const foundModel = models.find(
|
||||
(m) =>
|
||||
m.id === selectedModelId &&
|
||||
(m.quantization ?? null) === selectedQuantization,
|
||||
(m) => m.id === selectedModelId,
|
||||
)}
|
||||
{#if foundModel}
|
||||
{@const sizeGB = getModelSizeGB(foundModel)}
|
||||
@@ -2485,7 +2424,7 @@
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate"
|
||||
>{getModelDisplayName(foundModel)}</span
|
||||
>{foundModel.name || foundModel.id}</span
|
||||
>
|
||||
</span>
|
||||
<span class="text-white/50 text-xs flex-shrink-0"
|
||||
@@ -2564,7 +2503,6 @@
|
||||
onclick={() => {
|
||||
if (modelCanFit) {
|
||||
selectPreviewModel(model.id);
|
||||
selectedQuantization = model.quantization ?? null;
|
||||
saveLaunchDefaults();
|
||||
isModelDropdownOpen = false;
|
||||
modelDropdownSearch = "";
|
||||
@@ -2572,8 +2510,7 @@
|
||||
}}
|
||||
disabled={!modelCanFit}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
|
||||
model.id &&
|
||||
selectedQuantization === (model.quantization ?? null)
|
||||
model.id
|
||||
? 'bg-transparent text-exo-yellow cursor-pointer'
|
||||
: modelCanFit
|
||||
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
|
||||
@@ -2618,9 +2555,7 @@
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate"
|
||||
>{getModelDisplayName(model)}</span
|
||||
>
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
</span>
|
||||
<span
|
||||
class="flex-shrink-0 text-xs {modelCanFit
|
||||
@@ -2830,16 +2765,14 @@
|
||||
</div>
|
||||
{:else}
|
||||
{@const selectedModel = models.find(
|
||||
(m) =>
|
||||
m.id === selectedModelId &&
|
||||
(m.quantization ?? null) === selectedQuantization,
|
||||
(m) => m.id === selectedModelId,
|
||||
)}
|
||||
{@const allPreviews = filteredPreviews()}
|
||||
{#if selectedModel && allPreviews.length > 0}
|
||||
{@const downloadStatus = getModelDownloadStatus(
|
||||
selectedModel.id,
|
||||
)}
|
||||
{@const tags = modelTags()[getModelKey(selectedModel)] || []}
|
||||
{@const tags = modelTags()[selectedModel.id] || []}
|
||||
<div class="space-y-3">
|
||||
{#each allPreviews as apiPreview, i}
|
||||
<div
|
||||
@@ -2857,14 +2790,13 @@
|
||||
>
|
||||
<ModelCard
|
||||
model={selectedModel}
|
||||
isLaunching={launchingModelId ===
|
||||
getModelKey(selectedModel)}
|
||||
isLaunching={launchingModelId === selectedModel.id}
|
||||
{downloadStatus}
|
||||
nodes={data?.nodes ?? {}}
|
||||
sharding={apiPreview.sharding}
|
||||
runtime={apiPreview.instance_meta}
|
||||
onLaunch={() =>
|
||||
launchInstance(selectedModel, apiPreview)}
|
||||
launchInstance(selectedModel.id, apiPreview)}
|
||||
{tags}
|
||||
{apiPreview}
|
||||
modelIdOverride={apiPreview.model_id}
|
||||
@@ -3013,8 +2945,6 @@
|
||||
{@const isRunning = statusText === "RUNNING"}
|
||||
<!-- Instance Card -->
|
||||
{@const instanceModelId = getInstanceModelId(instance)}
|
||||
{@const instanceQuantization =
|
||||
getInstanceQuantization(instance)}
|
||||
{@const instanceInfo = getInstanceInfo(instance)}
|
||||
{@const instanceConnections =
|
||||
getInstanceConnections(instance)}
|
||||
@@ -3030,10 +2960,7 @@
|
||||
instanceModelId !== "Unknown" &&
|
||||
instanceModelId !== "Unknown Model"
|
||||
) {
|
||||
setSelectedChatModel(
|
||||
instanceModelId,
|
||||
instanceQuantization,
|
||||
);
|
||||
setSelectedChatModel(instanceModelId);
|
||||
}
|
||||
}}
|
||||
onkeydown={(e) => {
|
||||
@@ -3043,10 +2970,7 @@
|
||||
instanceModelId !== "Unknown" &&
|
||||
instanceModelId !== "Unknown Model"
|
||||
) {
|
||||
setSelectedChatModel(
|
||||
instanceModelId,
|
||||
instanceQuantization,
|
||||
);
|
||||
setSelectedChatModel(instanceModelId);
|
||||
}
|
||||
}
|
||||
}}
|
||||
@@ -3137,9 +3061,7 @@
|
||||
<div
|
||||
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
|
||||
>
|
||||
{getInstanceModelId(instance)}{instanceQuantization
|
||||
? ` (${instanceQuantization}-bit)`
|
||||
: ""}
|
||||
{getInstanceModelId(instance)}
|
||||
</div>
|
||||
<div class="text-white/60 text-xs font-mono">
|
||||
Strategy: <span class="text-white/80"
|
||||
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
|
||||
@@ -141,19 +141,14 @@ def chunk_to_response(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_card(
|
||||
model_id: ModelId, quantization: int | None = None
|
||||
) -> ModelCard:
|
||||
async def resolve_model_card(model_id: ModelId) -> ModelCard:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == ModelId(model_id):
|
||||
if quantization is None and card.quantization is None:
|
||||
return card
|
||||
if card.quantization == quantization:
|
||||
return card
|
||||
return card
|
||||
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
@@ -359,7 +354,7 @@ class API:
|
||||
model_id: ModelId,
|
||||
node_ids: Annotated[list[NodeId] | None, Query()] = None,
|
||||
) -> PlacementPreviewResponse:
|
||||
seen: set[tuple[ModelId, int | None, Sharding, InstanceMeta, int]] = set()
|
||||
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
|
||||
previews: list[PlacementPreview] = []
|
||||
required_nodes = set(node_ids) if node_ids else None
|
||||
|
||||
@@ -401,32 +396,17 @@ class API:
|
||||
required_nodes=required_nodes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (
|
||||
model_card.model_id,
|
||||
model_card.quantization,
|
||||
sharding,
|
||||
instance_meta,
|
||||
0,
|
||||
) not in seen:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
quantization=model_card.quantization,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
model_card.quantization,
|
||||
sharding,
|
||||
instance_meta,
|
||||
0,
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
current_ids = set(self.state.instances.keys())
|
||||
@@ -437,32 +417,17 @@ class API:
|
||||
]
|
||||
|
||||
if len(new_instances) != 1:
|
||||
if (
|
||||
model_card.model_id,
|
||||
model_card.quantization,
|
||||
sharding,
|
||||
instance_meta,
|
||||
0,
|
||||
) not in seen:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
quantization=model_card.quantization,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error="Expected exactly one new instance from placement",
|
||||
)
|
||||
)
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
model_card.quantization,
|
||||
sharding,
|
||||
instance_meta,
|
||||
0,
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
instance = new_instances[0]
|
||||
@@ -482,7 +447,6 @@ class API:
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
model_card.quantization,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
@@ -490,7 +454,6 @@ class API:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
quantization=model_card.quantization,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
@@ -501,7 +464,6 @@ class API:
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
model_card.quantization,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
@@ -707,18 +669,6 @@ class API:
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
)
|
||||
|
||||
def _has_matching_instance(
|
||||
self, model_id: ModelId, quantization: int | None
|
||||
) -> bool:
|
||||
"""Check if there's a running instance matching the model_id and quantization."""
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id == model_id
|
||||
and instance.shard_assignments.quantization == quantization
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def chat_completions(
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
@@ -770,23 +720,22 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _validate_image_model(
|
||||
self, model: str, quantization: int | None = None
|
||||
) -> tuple[ModelId, int | None]:
|
||||
"""Validate model exists and return resolved model ID and quantization.
|
||||
async def _validate_image_model(self, model: str) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
Returns tuple of (model_id, quantization).
|
||||
"""
|
||||
model_card = await resolve_model_card(ModelId(model), quantization)
|
||||
model_card = await resolve_model_card(ModelId(model))
|
||||
resolved_model = model_card.model_id
|
||||
resolved_quant = model_card.quantization
|
||||
if not self._has_matching_instance(ModelId(resolved_model), resolved_quant):
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(resolved_model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {resolved_model}"
|
||||
)
|
||||
return resolved_model, resolved_quant
|
||||
return resolved_model
|
||||
|
||||
async def get_image(self, image_id: str) -> FileResponse:
|
||||
stored = self._image_store.get(Id(image_id))
|
||||
@@ -822,9 +771,7 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model, payload.quantization = await self._validate_image_model(
|
||||
payload.model, payload.quantization
|
||||
)
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -1069,9 +1016,7 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model, payload.quantization = await self._validate_image_model(
|
||||
payload.model, payload.quantization
|
||||
)
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
@@ -1093,7 +1038,6 @@ class API:
|
||||
image: UploadFile,
|
||||
prompt: str,
|
||||
model: str,
|
||||
quantization: int | None,
|
||||
n: int,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
@@ -1106,9 +1050,7 @@ class API:
|
||||
advanced_params: AdvancedImageParams | None,
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model, resolved_quant = await self._validate_image_model(
|
||||
model, quantization
|
||||
)
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
@@ -1127,7 +1069,6 @@ class API:
|
||||
total_input_chunks=total_chunks,
|
||||
prompt=prompt,
|
||||
model=resolved_model,
|
||||
quantization=resolved_quant,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1166,7 +1107,6 @@ class API:
|
||||
image: UploadFile = File(...), # noqa: B008
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
quantization: str | None = Form(None),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
@@ -1181,9 +1121,6 @@ class API:
|
||||
# Parse string form values to proper types
|
||||
stream_bool = stream.lower() in ("true", "1", "yes")
|
||||
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
|
||||
quantization_int = (
|
||||
int(quantization) if quantization and quantization.isdigit() else None
|
||||
)
|
||||
|
||||
parsed_advanced_params: AdvancedImageParams | None = None
|
||||
if advanced_params:
|
||||
@@ -1196,7 +1133,6 @@ class API:
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
quantization=quantization_int,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1233,7 +1169,6 @@ class API:
|
||||
image: UploadFile = File(...), # noqa: B008
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
quantization: str | None = Form(None),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
@@ -1243,10 +1178,6 @@ class API:
|
||||
advanced_params: str | None = Form(None),
|
||||
) -> BenchImageGenerationResponse:
|
||||
"""Handle benchmark image editing requests with generation stats."""
|
||||
quantization_int = (
|
||||
int(quantization) if quantization and quantization.isdigit() else None
|
||||
)
|
||||
|
||||
parsed_advanced_params: AdvancedImageParams | None = None
|
||||
if advanced_params:
|
||||
with contextlib.suppress(Exception):
|
||||
@@ -1258,7 +1189,6 @@ class API:
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
quantization=quantization_int,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1293,7 +1223,6 @@ class API:
|
||||
data=[
|
||||
ModelListModel(
|
||||
id=card.model_id,
|
||||
quantization=card.quantization,
|
||||
hugging_face_id=card.model_id,
|
||||
name=card.model_id.short(),
|
||||
description="",
|
||||
|
||||
@@ -137,7 +137,6 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
quantization=model_card.quantization,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
@@ -171,7 +170,6 @@ def get_shard_assignments_for_tensor_parallel(
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
quantization=model_card.quantization,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,6 @@ class ModelCard(CamelCaseModel):
|
||||
supports_tensor: bool
|
||||
tasks: list[ModelTask]
|
||||
components: list[ComponentInfo] | None = None
|
||||
quantization: int | None = None
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -414,7 +413,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
@@ -429,7 +428,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -443,7 +442,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57,
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -471,7 +470,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -485,7 +484,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57,
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -544,7 +543,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60,
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
@@ -552,10 +551,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -578,7 +577,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60,
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
@@ -586,10 +585,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -611,91 +610,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _create_image_model_quant_variants(
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
quantizations = [8, 6, 5, 4, 3]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
quantization=None,
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
quantization=quant,
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ class ErrorResponse(BaseModel):
|
||||
|
||||
class ModelListModel(BaseModel):
|
||||
id: str
|
||||
quantization: int | None = None
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "exo"
|
||||
@@ -217,7 +216,6 @@ class CreateInstanceParams(BaseModel):
|
||||
|
||||
class PlacementPreview(BaseModel):
|
||||
model_id: ModelId
|
||||
quantization: int | None = None
|
||||
sharding: Sharding
|
||||
instance_meta: InstanceMeta
|
||||
instance: Instance | None = None
|
||||
@@ -269,7 +267,6 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
style: str | None = "vivid"
|
||||
user: str | None = None
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
quantization: int | None = None
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
@@ -295,7 +292,6 @@ class ImageEditsTaskParams(BaseModel):
|
||||
stream: bool | None = False
|
||||
user: str | None = None
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
quantization: int | None = None
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
@@ -307,7 +303,6 @@ class ImageEditsInternalParams(BaseModel):
|
||||
total_input_chunks: int = 0
|
||||
prompt: str
|
||||
model: str
|
||||
quantization: int | None = None
|
||||
n: int | None = 1
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
|
||||
@@ -82,7 +82,6 @@ RunnerStatus = (
|
||||
|
||||
class ShardAssignments(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
quantization: int | None = None
|
||||
runner_to_shard: Mapping[RunnerId, ShardMetadata]
|
||||
node_to_runner: Mapping[NodeId, RunnerId]
|
||||
|
||||
|
||||
@@ -71,10 +71,8 @@ class DistributedImageModel:
|
||||
def from_bound_instance(
|
||||
cls, bound_instance: BoundInstance
|
||||
) -> "DistributedImageModel":
|
||||
model_card = bound_instance.bound_shard.model_card
|
||||
model_id = model_card.model_id
|
||||
model_id = bound_instance.bound_shard.model_card.model_id
|
||||
model_path = build_model_path(model_id)
|
||||
quantize = model_card.quantization
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
@@ -95,7 +93,6 @@ class DistributedImageModel:
|
||||
local_path=model_path,
|
||||
shard_metadata=shard_metadata,
|
||||
group=group,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
|
||||
|
||||
@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
|
||||
@@ -170,10 +170,10 @@ def mlx_distributed_init(
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
40
uv.lock
generated
40
uv.lock
generated
@@ -376,8 +376,8 @@ dependencies = [
|
||||
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -412,8 +412,8 @@ requires-dist = [
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
|
||||
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
@@ -994,8 +994,8 @@ dependencies = [
|
||||
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1022,18 +1022,12 @@ wheels = [
|
||||
name = "mlx"
|
||||
version = "0.30.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
resolution-markers = [
|
||||
"sys_platform == 'linux'",
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
|
||||
]
|
||||
@@ -1046,6 +1040,14 @@ cuda13 = [
|
||||
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.4.dev20260121+fbe306f9"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.3"
|
||||
@@ -1076,7 +1078,7 @@ version = "0.30.4"
|
||||
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1084,16 +1086,6 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "more-itertools"
|
||||
version = "10.8.0"
|
||||
|
||||
Reference in New Issue
Block a user