Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-31 01:01:11 -05:00)

Compare commits: model-card...ciaran/pro (19 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 92de82bdf6 |  |
|  | e792b19f5d |  |
|  | edc3a88b12 |  |
|  | 0ae0c788d5 |  |
|  | 09abf44d49 |  |
|  | ac4b78349b |  |
|  | 96c440c3b1 |  |
|  | c14d63cf61 |  |
|  | cd946742f7 |  |
|  | a5bc38ad1f |  |
|  | 2a4e0d4629 |  |
|  | 46a14153dd |  |
|  | 9ba61f3733 |  |
|  | d9eca75895 |  |
|  | 9dabde7e57 |  |
|  | a31942ce12 |  |
|  | 7cc313b22a |  |
|  | 2837225dc7 |  |
|  | e4c6a7dbb4 |  |
12
.github/actions/typecheck/action.yml
vendored
@@ -1,12 +0,0 @@
name: Type Check
description: "Run type checker"
runs:
  using: "composite"
  steps:
    - name: Run type checker
      run: |
        nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
        nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
      shell: bash
82
.github/workflows/pipeline.yml
vendored
@@ -26,73 +26,14 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Pull LFS files
|
||||
run: |
|
||||
echo "Pulling Git LFS files..."
|
||||
git lfs pull
|
||||
shell: bash
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix binary exists directly
|
||||
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
nix --version
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Found nix profile script, sourcing..."
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
nix --version
|
||||
elif command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
nix --version
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Configure basedpyright include for local MLX
|
||||
run: |
|
||||
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
|
||||
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
|
||||
awk '
|
||||
BEGIN { in=0 }
|
||||
/^\[tool\.basedpyright\]/ { in=1; print; next }
|
||||
in && /^\[/ { in=0 } # next section
|
||||
in && /^[ \t]*include[ \t]*=/ {
|
||||
print "include = [\"/Users/Shared/mlx\"]"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
|
||||
|
||||
echo "New [tool.basedpyright] section:"
|
||||
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
|
||||
else
|
||||
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
|
||||
fi
|
||||
else
|
||||
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
@@ -191,3 +132,14 @@ jobs:
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
- name: Run pytest (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
|
||||
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
|
||||
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
@@ -342,6 +342,8 @@
        SDKROOT = macosx;
        SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
        SWIFT_OPTIMIZATION_LEVEL = "-Onone";
        SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
        GCC_TREAT_WARNINGS_AS_ERRORS = YES;
      };
      name = Debug;
    };
@@ -397,6 +399,8 @@
        MTL_FAST_MATH = YES;
        SDKROOT = macosx;
        SWIFT_COMPILATION_MODE = wholemodule;
        SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
        GCC_TREAT_WARNINGS_AS_ERRORS = YES;
      };
      name = Release;
    };
10
dashboard/package-lock.json
generated
@@ -865,7 +865,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -905,7 +904,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1522,7 +1520,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1532,7 +1529,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1945,7 +1941,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2653,7 +2648,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2696,7 +2690,6 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2869,7 +2862,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3014,7 +3006,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3036,7 +3027,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -173,6 +173,41 @@ export interface PlacementPreviewResponse {
|
||||
previews: PlacementPreview[];
|
||||
}
|
||||
|
||||
interface ImageApiResponse {
|
||||
created: number;
|
||||
data: Array<{ b64_json?: string; url?: string }>;
|
||||
}
|
||||
|
||||
// Trace API response types
|
||||
export interface TraceCategoryStats {
|
||||
totalUs: number;
|
||||
count: number;
|
||||
minUs: number;
|
||||
maxUs: number;
|
||||
avgUs: number;
|
||||
}
|
||||
|
||||
export interface TraceRankStats {
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
}
|
||||
|
||||
export interface TraceStatsResponse {
|
||||
taskId: string;
|
||||
totalWallTimeUs: number;
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
byRank: Record<number, TraceRankStats>;
|
||||
}
|
||||
|
||||
export interface TraceListItem {
|
||||
taskId: string;
|
||||
createdAt: string;
|
||||
fileSize: number;
|
||||
}
|
||||
|
||||
export interface TraceListResponse {
|
||||
traces: TraceListItem[];
|
||||
}
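// Illustrative example (not from the repository): a TraceStatsResponse for a
// hypothetical task. Category keys follow the hierarchical "phase" and
// "phase/subcategory" convention that the trace pages parse; the names and
// numbers below are assumptions for orientation only.
const exampleTraceStats: TraceStatsResponse = {
  taskId: "task-1234",
  totalWallTimeUs: 5_200_000,
  byCategory: {
    sync: { totalUs: 3_000_000, count: 20, minUs: 100_000, maxUs: 220_000, avgUs: 150_000 },
    "sync/compute": { totalUs: 2_400_000, count: 20, minUs: 80_000, maxUs: 180_000, avgUs: 120_000 },
  },
  byRank: {
    0: { byCategory: { sync: { totalUs: 1_500_000, count: 10, minUs: 100_000, maxUs: 220_000, avgUs: 150_000 } } },
  },
};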
|
||||
|
||||
interface RawStateResponse {
|
||||
topology?: RawTopology;
|
||||
instances?: Record<
|
||||
@@ -2095,107 +2130,137 @@ class AppStore {
|
||||
throw new Error(`API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await response.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
|
||||
const numImages = params.numImages;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img, index) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `generated-image-${index + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
const numImages = params.numImages;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error generating image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2343,69 +2408,98 @@ class AppStore {
|
||||
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error editing image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2491,6 +2585,49 @@ class AppStore {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List all available traces
|
||||
*/
|
||||
async listTraces(): Promise<TraceListResponse> {
|
||||
const response = await fetch("/v1/traces");
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to list traces: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceListResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a trace exists for a given task ID
|
||||
*/
|
||||
async checkTraceExists(taskId: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get computed statistics for a task's trace
|
||||
*/
|
||||
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
|
||||
const response = await fetch(
|
||||
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
|
||||
);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch trace stats: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceStatsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the raw trace file (for Perfetto)
|
||||
*/
|
||||
getTraceRawUrl(taskId: string): string {
|
||||
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
|
||||
}
|
||||
}
|
||||
|
||||
export const appStore = new AppStore();
|
||||
@@ -2602,3 +2739,12 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
|
||||
appStore.startDownload(nodeId, shardMetadata);
|
||||
export const deleteDownload = (nodeId: string, modelId: string) =>
|
||||
appStore.deleteDownload(nodeId, modelId);
|
||||
|
||||
// Trace actions
|
||||
export const listTraces = () => appStore.listTraces();
|
||||
export const checkTraceExists = (taskId: string) =>
|
||||
appStore.checkTraceExists(taskId);
|
||||
export const fetchTraceStats = (taskId: string) =>
|
||||
appStore.fetchTraceStats(taskId);
|
||||
export const getTraceRawUrl = (taskId: string) =>
|
||||
appStore.getTraceRawUrl(taskId);
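// Illustrative usage (not part of the repository): how a component might call
// the trace actions exported above. The function name and logging are
// assumptions; only listTraces/checkTraceExists/fetchTraceStats/getTraceRawUrl
// come from this module.
async function logTraceSummary(taskId: string): Promise<void> {
  const { traces } = await listTraces();
  console.log(`${traces.length} trace(s) available`);
  if (!(await checkTraceExists(taskId))) return;
  const stats = await fetchTraceStats(taskId);
  console.log(`trace ${taskId}: ${stats.totalWallTimeUs} us wall time`);
  console.log(`raw trace (for Perfetto): ${getTraceRawUrl(taskId)}`);
}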
|
||||
|
||||
172
dashboard/src/routes/traces/+page.svelte
Normal file
@@ -0,0 +1,172 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.min(
|
||||
Math.floor(Math.log(bytes) / Math.log(1024)),
|
||||
units.length - 1,
|
||||
);
|
||||
const val = bytes / Math.pow(1024, i);
|
||||
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
|
||||
}
|
||||
|
||||
function formatDate(isoString: string): string {
|
||||
const date = new Date(isoString);
|
||||
return date.toLocaleString();
|
||||
}
|
||||
|
||||
async function openInPerfetto(taskId: string) {
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
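// Note: the handshake above follows Perfetto's postMessage deep-link flow:
// open https://ui.perfetto.dev, keep posting "PING" until the UI replies
// "PONG", then post { perfetto: { buffer, title } } so it loads the trace.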
|
||||
|
||||
async function refresh() {
|
||||
loading = true;
|
||||
error = null;
|
||||
try {
|
||||
const response = await listTraces();
|
||||
traces = response.traces;
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load traces";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
refresh();
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Traces
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={refresh}
|
||||
disabled={loading}
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading traces...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if traces.length === 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
|
||||
>
|
||||
<div class="text-sm">No traces found.</div>
|
||||
<div class="text-xs text-exo-light-gray/70">
|
||||
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
{#each traces as trace}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
<div class="text-xs text-exo-light-gray font-mono mt-1">
|
||||
{formatDate(trace.createdAt)} • {formatBytes(
|
||||
trace.fileSize,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
>
|
||||
View Stats
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
|
||||
onclick={() => openInPerfetto(trace.taskId)}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
347
dashboard/src/routes/traces/[taskId]/+page.svelte
Normal file
@@ -0,0 +1,347 @@
|
||||
<script lang="ts">
|
||||
import { page } from "$app/stores";
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
fetchTraceStats,
|
||||
getTraceRawUrl,
|
||||
type TraceStatsResponse,
|
||||
type TraceCategoryStats,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
const taskId = $derived($page.params.taskId);
|
||||
|
||||
let stats = $state<TraceStatsResponse | null>(null);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatDuration(us: number): string {
|
||||
if (us < 1000) return `${us.toFixed(0)}us`;
|
||||
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
|
||||
return `${(us / 1_000_000).toFixed(2)}s`;
|
||||
}
|
||||
|
||||
function formatPercentage(part: number, total: number): string {
|
||||
if (total === 0) return "0.0%";
|
||||
return `${((part / total) * 100).toFixed(1)}%`;
|
||||
}
|
||||
|
||||
// Parse hierarchical categories like "sync/compute" into phases
|
||||
type PhaseData = {
|
||||
name: string;
|
||||
subcategories: { name: string; stats: TraceCategoryStats }[];
|
||||
totalUs: number; // From outer span (e.g., "sync" category)
|
||||
stepCount: number; // Count of outer span events
|
||||
};
|
||||
|
||||
function parsePhases(
|
||||
byCategory: Record<string, TraceCategoryStats>,
|
||||
): PhaseData[] {
|
||||
const phases = new Map<
|
||||
string,
|
||||
{
|
||||
subcats: Map<string, TraceCategoryStats>;
|
||||
outerStats: TraceCategoryStats | null;
|
||||
}
|
||||
>();
|
||||
|
||||
for (const [category, catStats] of Object.entries(byCategory)) {
|
||||
if (category.includes("/")) {
|
||||
const [phase, subcat] = category.split("/", 2);
|
||||
if (!phases.has(phase)) {
|
||||
phases.set(phase, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(phase)!.subcats.set(subcat, catStats);
|
||||
} else {
|
||||
// Outer span - this IS the phase total
|
||||
if (!phases.has(category)) {
|
||||
phases.set(category, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(category)!.outerStats = catStats;
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(phases.entries())
|
||||
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
|
||||
.map(([name, data]) => ({
|
||||
name,
|
||||
subcategories: Array.from(data.subcats.entries())
|
||||
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
|
||||
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
|
||||
totalUs: data.outerStats!.totalUs, // Outer span total
|
||||
stepCount: data.outerStats!.count, // Number of steps
|
||||
}))
|
||||
.sort((a, b) => b.totalUs - a.totalUs);
|
||||
}
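// Worked example (illustrative values): given an outer "sync" span and a
// "sync/compute" sub-span, parsePhases nests the subcategory under the phase
// and takes the phase total and step count from the outer span.
//
//   parsePhases({
//     "sync":         { totalUs: 300, count: 3, minUs: 90, maxUs: 110, avgUs: 100 },
//     "sync/compute": { totalUs: 240, count: 3, minUs: 70, maxUs: 90,  avgUs: 80 },
//   })
//   // => [ { name: "sync", totalUs: 300, stepCount: 3,
//   //        subcategories: [ { name: "compute", stats: { totalUs: 240, ... } } ] } ]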
|
||||
|
||||
async function openInPerfetto() {
|
||||
if (!taskId) return;
|
||||
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
onMount(async () => {
|
||||
if (!taskId) {
|
||||
error = "No task ID provided";
|
||||
loading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
stats = await fetchTraceStats(taskId);
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load trace";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
});
|
||||
|
||||
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
|
||||
const sortedRanks = $derived(
|
||||
stats
|
||||
? Object.keys(stats.byRank)
|
||||
.map(Number)
|
||||
.sort((a, b) => a - b)
|
||||
: [],
|
||||
);
|
||||
const nodeCount = $derived(sortedRanks.length || 1);
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Trace
|
||||
</h1>
|
||||
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
|
||||
{taskId}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<a
|
||||
href="#/traces"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
>
|
||||
All Traces
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
|
||||
onclick={openInPerfetto}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading trace data...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if stats}
|
||||
<!-- Wall Time Summary -->
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
Summary
|
||||
</h2>
|
||||
<div class="text-3xl font-mono text-exo-yellow">
|
||||
{formatDuration(stats.totalWallTimeUs)}
|
||||
</div>
|
||||
<div class="text-xs text-exo-light-gray">Total wall time</div>
|
||||
</div>
|
||||
|
||||
<!-- By Phase -->
|
||||
{#if phases.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
|
||||
</h2>
|
||||
<div class="space-y-4">
|
||||
{#each phases as phase}
|
||||
{@const normalizedTotal = phase.totalUs / nodeCount}
|
||||
{@const normalizedStepCount = phase.stepCount / nodeCount}
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-sm font-mono text-white">{phase.name}</span>
|
||||
<span class="text-sm font-mono">
|
||||
<span class="text-exo-yellow"
|
||||
>{formatDuration(normalizedTotal)}</span
|
||||
>
|
||||
<span class="text-exo-light-gray ml-2">
|
||||
({normalizedStepCount} steps, {formatDuration(
|
||||
normalizedTotal / normalizedStepCount,
|
||||
)}/step)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-4 space-y-1.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const normalizedSubcat =
|
||||
subcat.stats.totalUs / nodeCount}
|
||||
{@const pct = formatPercentage(
|
||||
normalizedSubcat,
|
||||
normalizedTotal,
|
||||
)}
|
||||
{@const perStep = normalizedSubcat / normalizedStepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-xs font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray">{subcat.name}</span>
|
||||
<span class="text-white">
|
||||
{formatDuration(normalizedSubcat)}
|
||||
<span class="text-exo-light-gray ml-2">({pct})</span>
|
||||
<span class="text-exo-light-gray/60 ml-2"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
<!-- Progress bar -->
|
||||
<div
|
||||
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
style="width: {pct}"
|
||||
></div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- By Rank -->
|
||||
{#if sortedRanks.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Rank
|
||||
</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{#each sortedRanks as rank}
|
||||
{@const rankStats = stats.byRank[rank]}
|
||||
{@const rankPhases = parsePhases(rankStats.byCategory)}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
|
||||
>
|
||||
<div class="text-sm font-mono text-exo-yellow">
|
||||
Rank {rank}
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
{#each rankPhases as phase}
|
||||
<div class="space-y-1">
|
||||
<div class="flex items-center justify-between text-xs">
|
||||
<span class="font-mono text-exo-light-gray"
|
||||
>{phase.name}</span
|
||||
>
|
||||
<span class="font-mono text-white">
|
||||
{formatDuration(phase.totalUs)}
|
||||
<span class="text-exo-light-gray/50 ml-1">
|
||||
({phase.stepCount}x)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-2 space-y-0.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const pct = formatPercentage(
|
||||
subcat.stats.totalUs,
|
||||
phase.totalUs,
|
||||
)}
|
||||
{@const perStep =
|
||||
subcat.stats.totalUs / phase.stepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray/70"
|
||||
>{subcat.name}</span
|
||||
>
|
||||
<span class="text-exo-light-gray">
|
||||
{formatDuration(subcat.stats.totalUs)}
|
||||
<span class="text-exo-light-gray/50"
|
||||
>({pct})</span
|
||||
>
|
||||
<span class="text-exo-light-gray/30 ml-1"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
65
flake.lock
generated
@@ -21,7 +21,9 @@
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
@@ -149,19 +151,44 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763662255,
|
||||
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"lastModified": 1764134915,
|
||||
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,7 +205,10 @@
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"treefmt-nix": "treefmt-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
@@ -239,6 +269,29 @@
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1767701098,
|
||||
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
27
flake.nix
@@ -24,6 +24,26 @@
    dream2nix = {
      url = "github:nix-community/dream2nix";
      inputs.nixpkgs.follows = "nixpkgs";
      inputs.pyproject-nix.follows = "pyproject-nix";
    };

    # Python packaging with uv2nix
    pyproject-nix = {
      url = "github:pyproject-nix/pyproject.nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    uv2nix = {
      url = "github:pyproject-nix/uv2nix";
      inputs.pyproject-nix.follows = "pyproject-nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    pyproject-build-systems = {
      url = "github:pyproject-nix/build-system-pkgs";
      inputs.pyproject-nix.follows = "pyproject-nix";
      inputs.uv2nix.follows = "uv2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    # Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -48,6 +68,7 @@
        inputs.treefmt-nix.flakeModule
        ./dashboard/parts.nix
        ./rust/parts.nix
        ./python/parts.nix
      ];

      perSystem =
@@ -88,12 +109,6 @@
          };
        };

        checks.lint = pkgs.runCommand "lint-check" { } ''
          export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
          ${pkgs.ruff}/bin/ruff check ${inputs.self}/
          touch $out
        '';

        packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
          let
            uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
@@ -10,7 +10,6 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"

if not ENTRYPOINT.is_file():
@@ -19,9 +18,6 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
    raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")

if not RESOURCES_DIR.is_dir():
    raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")

if not EXO_SHARED_MODELS_DIR.is_dir():
    raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")

@@ -62,7 +58,6 @@ HIDDEN_IMPORTS = sorted(

DATAS: list[tuple[str, str]] = [
    (str(DASHBOARD_DIR), "dashboard"),
    (str(RESOURCES_DIR), "resources"),
    (str(MLX_LIB_DIR), "mlx/lib"),
    (str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]
93
python/parts.nix
Normal file
@@ -0,0 +1,93 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
|
||||
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
src = self'.packages.exo_pyo3_bindings;
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
};
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
};
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard}
|
||||
done
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 33327437952

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 0

[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"

[components.storage_size]
in_bytes = 9524621312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 23802816640

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 15470210592

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 0

[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"

[components.storage_size]
in_bytes = 9524621312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 5945589280

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 21415799872

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 0

[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"

[components.storage_size]
in_bytes = 9524621312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 11891178560

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 33306978432

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 0

[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"

[components.storage_size]
in_bytes = 9524621312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 23782357120

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 26799533856

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 16584333312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 10215200544

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 37014734400

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 16584333312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 20430401088

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]

[storage_size]
in_bytes = 26799533856

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 16584333312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 10215200544

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]

[storage_size]
in_bytes = 37014734400

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 16584333312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 20430401088

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]

[storage_size]
in_bytes = 57445135488

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 16584333312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 40860802176

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]

[storage_size]
in_bytes = 57445135488

[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false

[components.storage_size]
in_bytes = 16584333312

[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"

[components.storage_size]
in_bytes = 40860802176

[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false

[components.storage_size]
in_bytes = 0
@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 405874409472
@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 765577920512
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 122406567936
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 229780750336
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 198556925568
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 286737579648
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 396963397248
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 19327352832
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 22548578304
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 26843545600
@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 34359738368
@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 620622774272
@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 706522120192
@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 662498705408
@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 729808896
@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 1863319552
@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 3501195264
@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 40652242944
@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 76799803392
@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 40652242944
@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 4637851648
@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 8954839040
@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 16882073600
@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 100086644736
@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 242986745856
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 342884352
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 698351616
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 141733920768
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 268435456000
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 17612931072
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 33279705088
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 289910292480
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 579820584960
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 46976204800
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 88814387200
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 88814387200
@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 88814387200
@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 70652212224
@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 12025908224
@@ -1,8 +0,0 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 144383672320
@@ -7,7 +7,7 @@ from loguru import logger

from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:


async def build_base_shard(model_id: ModelId) -> ShardMetadata:
    model_card = await ModelCard.fetch_from_hf(model_id)
    model_card = await ModelCard.load(model_id)
    return PipelineShardMetadata(
        model_card=model_card,
        device_rank=0,
@@ -160,15 +160,14 @@ class ResumableShardDownloader(ShardDownloader):
        # Kick off download status coroutines concurrently
        tasks = [
            asyncio.create_task(_status_for_model(model_card.model_id))
            for model_card in await get_model_cards()
            for model_card in MODEL_CARDS.values()
        ]

        for task in asyncio.as_completed(tasks):
            try:
                yield await task
            # TODO: except Exception
            except Exception as e:
                logger.error("Error downloading shard:", e)
                logger.warning(f"Error downloading shard: {type(e).__name__}")

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
@@ -3,7 +3,9 @@ import contextlib
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -22,14 +24,18 @@ from loguru import logger
|
||||
from exo.master.image_store import ImageStore
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import DASHBOARD_DIR, EXO_IMAGE_CACHE_DIR, EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.constants import (
|
||||
EXO_IMAGE_CACHE_DIR,
|
||||
EXO_MAX_CHUNK_SIZE,
|
||||
)
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import (
|
||||
MODEL_CARDS,
|
||||
ModelCard,
|
||||
ModelId,
|
||||
get_model_cards,
|
||||
)
|
||||
from exo.shared.tracing import compute_stats, load_trace_file
|
||||
from exo.shared.types.api import (
|
||||
AdvancedImageParams,
|
||||
BenchChatCompletionResponse,
|
||||
@@ -62,7 +68,16 @@ from exo.shared.types.api import (
|
||||
StartDownloadParams,
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
StreamOptions,
|
||||
ToolCall,
|
||||
TraceCategoryStats,
|
||||
TraceEventResponse,
|
||||
TraceListItem,
|
||||
TraceListResponse,
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -101,6 +116,7 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
|
||||
@@ -109,7 +125,9 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
chunk: TokenChunk | ToolCallChunk,
|
||||
command_id: CommandId,
|
||||
usage: Usage | None,
|
||||
) -> ChatCompletionResponse:
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
@@ -134,6 +152,7 @@ def chunk_to_response(
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -174,7 +193,7 @@ class API:
|
||||
self.app.mount(
|
||||
"/",
|
||||
StaticFiles(
|
||||
directory=DASHBOARD_DIR,
|
||||
directory=find_dashboard(),
|
||||
html=True,
|
||||
),
|
||||
name="dashboard",
|
||||
@@ -255,6 +274,10 @@ class API:
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
@@ -345,7 +368,10 @@ class API:
|
||||
if len(list(self.state.topology.list_nodes())) == 0:
|
||||
return PlacementPreviewResponse(previews=[])
|
||||
|
||||
model_card = await ModelCard.load(model_id)
|
||||
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
|
||||
if not cards:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
|
||||
|
||||
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
|
||||
for sharding in (Sharding.Pipeline, Sharding.Tensor):
|
||||
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
|
||||
@@ -360,93 +386,96 @@ class API:
|
||||
# TODO: PDD
|
||||
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
|
||||
|
||||
for sharding, instance_meta, min_nodes in instance_combinations:
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
required_nodes=required_nodes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
for model_card in cards:
|
||||
for sharding, instance_meta, min_nodes in instance_combinations:
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error=str(exc),
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
required_nodes=required_nodes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
current_ids = set(self.state.instances.keys())
|
||||
new_instances = [
|
||||
instance
|
||||
for instance_id, instance in placements.items()
|
||||
if instance_id not in current_ids
|
||||
]
|
||||
current_ids = set(self.state.instances.keys())
|
||||
new_instances = [
|
||||
instance
|
||||
for instance_id, instance in placements.items()
|
||||
if instance_id not in current_ids
|
||||
]
|
||||
|
||||
if len(new_instances) != 1:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error="Expected exactly one new instance from placement",
|
||||
if len(new_instances) != 1:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error="Expected exactly one new instance from placement",
|
||||
)
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
instance = new_instances[0]
|
||||
shard_assignments = instance.shard_assignments
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
instance = new_instances[0]
|
||||
shard_assignments = instance.shard_assignments
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
per_node = total_bytes // len(placement_node_ids)
|
||||
remainder = total_bytes % len(placement_node_ids)
|
||||
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
|
||||
extra = 1 if index < remainder else 0
|
||||
memory_delta_by_node[str(node_id)] = per_node + extra
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
per_node = total_bytes // len(placement_node_ids)
|
||||
remainder = total_bytes % len(placement_node_ids)
|
||||
for index, node_id in enumerate(
|
||||
sorted(placement_node_ids, key=str)
|
||||
):
|
||||
extra = 1 if index < remainder else 0
|
||||
memory_delta_by_node[str(node_id)] = per_node + extra
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
seen.add(
|
||||
(
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return PlacementPreviewResponse(previews=previews)
|
||||
|
||||
@@ -500,9 +529,10 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, stream_options: StreamOptions | None = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
include_usage = stream_options.include_usage if stream_options else False
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
@@ -518,8 +548,10 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
usage = chunk.usage if include_usage else None
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
chunk, command_id, usage=usage
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
@@ -535,8 +567,9 @@ class API:
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
model: ModelId | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -561,6 +594,9 @@ class API:
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.usage is not None:
|
||||
usage = chunk.usage
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
@@ -582,6 +618,7 @@ class API:
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
@@ -589,7 +626,7 @@ class API:
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
model: ModelId | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
@@ -642,7 +679,7 @@ class API:
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
)
|
||||
@@ -669,7 +706,7 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
self._generate_chat_stream(command.command_id, payload.stream_options),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -698,12 +735,12 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _validate_image_model(self, model: str) -> ModelId:
|
||||
async def _validate_image_model(self, model: ModelId) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
"""
|
||||
model_card = await ModelCard.load(ModelId(model))
|
||||
model_card = await ModelCard.load(model)
|
||||
resolved_model = model_card.model_id
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
@@ -749,7 +786,7 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -994,7 +1031,7 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
@@ -1015,7 +1052,7 @@ class API:
|
||||
self,
|
||||
image: UploadFile,
|
||||
prompt: str,
|
||||
model: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
@@ -1110,7 +1147,7 @@ class API:
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1166,7 +1203,7 @@ class API:
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1209,7 +1246,7 @@ class API:
|
||||
supports_tensor=card.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
)
|
||||
for card in await get_model_cards()
|
||||
for card in MODEL_CARDS.values()
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1313,3 +1350,110 @@ class API:
|
||||
)
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
def _get_traces_dir(self) -> Path:
|
||||
return Path.home() / ".exo" / "traces"
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return self._get_traces_dir() / f"trace_{task_id}.json"
|
||||
|
||||
async def list_traces(self) -> TraceListResponse:
|
||||
traces_dir = self._get_traces_dir()
|
||||
traces: list[TraceListItem] = []
|
||||
|
||||
if not traces_dir.exists():
|
||||
return TraceListResponse(traces=[])
|
||||
|
||||
for trace_file in sorted(
|
||||
traces_dir.glob("trace_*.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
):
|
||||
# Extract task_id from filename (trace_{task_id}.json)
|
||||
task_id = trace_file.stem.removeprefix("trace_")
|
||||
stat = trace_file.stat()
|
||||
created_at = datetime.fromtimestamp(
|
||||
stat.st_mtime, tz=timezone.utc
|
||||
).isoformat()
|
||||
traces.append(
|
||||
TraceListItem(
|
||||
task_id=task_id,
|
||||
created_at=created_at,
|
||||
file_size=stat.st_size,
|
||||
)
|
||||
)
|
||||
|
||||
return TraceListResponse(traces=traces)
|
||||
|
||||
async def get_trace(self, task_id: str) -> TraceResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
trace_events = load_trace_file(trace_path)
|
||||
|
||||
return TraceResponse(
|
||||
task_id=task_id,
|
||||
traces=[
|
||||
TraceEventResponse(
|
||||
name=event.name,
|
||||
start_us=event.start_us,
|
||||
duration_us=event.duration_us,
|
||||
rank=event.rank,
|
||||
category=event.category,
|
||||
)
|
||||
for event in trace_events
|
||||
],
|
||||
)
|
||||
|
||||
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
trace_events = load_trace_file(trace_path)
|
||||
stats = compute_stats(trace_events)
|
||||
|
||||
return TraceStatsResponse(
|
||||
task_id=task_id,
|
||||
total_wall_time_us=stats.total_wall_time_us,
|
||||
by_category={
|
||||
category: TraceCategoryStats(
|
||||
total_us=cat_stats.total_us,
|
||||
count=cat_stats.count,
|
||||
min_us=cat_stats.min_us,
|
||||
max_us=cat_stats.max_us,
|
||||
avg_us=cat_stats.avg_us,
|
||||
)
|
||||
for category, cat_stats in stats.by_category.items()
|
||||
},
|
||||
by_rank={
|
||||
rank: TraceRankStats(
|
||||
by_category={
|
||||
category: TraceCategoryStats(
|
||||
total_us=cat_stats.total_us,
|
||||
count=cat_stats.count,
|
||||
min_us=cat_stats.min_us,
|
||||
max_us=cat_stats.max_us,
|
||||
avg_us=cat_stats.avg_us,
|
||||
)
|
||||
for category, cat_stats in rank_stats.items()
|
||||
}
|
||||
)
|
||||
for rank, rank_stats in stats.by_rank.items()
|
||||
},
|
||||
)
|
||||
|
||||
async def get_trace_raw(self, task_id: str) -> FileResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
return FileResponse(
|
||||
path=trace_path,
|
||||
media_type="application/json",
|
||||
filename=f"trace_{task_id}.json",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -11,6 +12,7 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.tracing import TraceEvent, export_trace, is_tracing_enabled
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
@@ -35,6 +37,8 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -86,6 +90,8 @@ class Master:
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -187,13 +193,14 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
instance_id=selected_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
@@ -201,6 +208,17 @@ class Master:
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if is_tracing_enabled():
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case ImageEdits():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
@@ -229,13 +247,14 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
instance_id=selected_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
@@ -243,6 +262,17 @@ class Master:
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if is_tracing_enabled():
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
@@ -335,6 +365,10 @@ class Master:
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
if isinstance(event, TracesCollected):
|
||||
self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
@@ -373,3 +407,38 @@ class Master:
|
||||
event=event.event,
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_traces_collected(self, event: TracesCollected) -> None:
|
||||
task_id = event.task_id
|
||||
if task_id not in self._pending_traces:
|
||||
self._pending_traces[task_id] = {}
|
||||
self._pending_traces[task_id][event.rank] = event.traces
|
||||
|
||||
if (
|
||||
task_id in self._expected_ranks
|
||||
and set(self._pending_traces[task_id].keys())
|
||||
>= self._expected_ranks[task_id]
|
||||
):
|
||||
self._merge_and_save_traces(task_id)
|
||||
|
||||
def _merge_and_save_traces(self, task_id: TaskId) -> None:
|
||||
all_traces: list[TraceEvent] = []
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
for t in trace_data:
|
||||
all_traces.append(
|
||||
TraceEvent(
|
||||
name=t.name,
|
||||
start_us=t.start_us,
|
||||
duration_us=t.duration_us,
|
||||
rank=t.rank,
|
||||
category=t.category,
|
||||
)
|
||||
)
|
||||
|
||||
output_path = Path.home() / ".exo" / "traces" / f"trace_{task_id}.json"
|
||||
export_trace(all_traces, output_path)
|
||||
logger.info(f"Merged traces saved to {output_path}")
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
@@ -216,6 +216,8 @@ def get_node_id_keypair(
    Obtains the :class:`Keypair` associated with this node-ID.
    Obtain the :class:`PeerId` by from it.
    """
    # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
    return Keypair.generate_ed25519()

def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
    return Path(str(path) + ".lock")
@@ -25,6 +25,7 @@ from exo.shared.types.events import (
    TestEvent,
    TopologyEdgeCreated,
    TopologyEdgeDeleted,
    TracesCollected,
)
from exo.shared.types.profiling import (
    NodeIdentity,
@@ -55,7 +56,11 @@ def event_apply(event: Event, state: State) -> State:
    """Apply an event to state."""
    match event:
        case (
            TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
            TestEvent()
            | ChunkGenerated()
            | TaskAcknowledged()
            | InputChunkReceived()
            | TracesCollected()
        ): # Pass-through events that don't modify state
            return state
        case InstanceCreated():
@@ -2,8 +2,6 @@ import os
import sys
from pathlib import Path

from exo.utils.dashboard_path import find_dashboard, find_resources

_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)


@@ -33,14 +31,6 @@ EXO_MODELS_DIR = (
    if _EXO_MODELS_DIR_ENV is None
    else Path.home() / _EXO_MODELS_DIR_ENV
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
    find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
    find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)

# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"
@@ -63,3 +53,9 @@ EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_ENABLE_IMAGE_MODELS = (
    os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)

EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "").lower() in (
    "1",
    "true",
    "yes",
)
@@ -12,42 +12,16 @@ from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PositiveInt,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from tomlkit.exceptions import TOMLKitError
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS, RESOURCES_DIR
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
# kinda ugly...
|
||||
# TODO: load search path from config.toml
|
||||
_csp = [Path(RESOURCES_DIR)]
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
_csp.append(Path(RESOURCES_DIR) / "image_models")
|
||||
|
||||
CARD_SEARCH_PATH = _csp
|
||||
|
||||
_card_cache: dict[ModelId, "ModelCard"] = {}
|
||||
|
||||
|
||||
async def _populate_card_cache():
|
||||
for path in CARD_SEARCH_PATH:
|
||||
async for toml_file in path.rglob("*.toml"):
|
||||
try:
|
||||
card = await ModelCard.load_from_path(toml_file)
|
||||
_card_cache[card.model_id] = card
|
||||
except (ValidationError, TOMLKitError):
|
||||
pass
|
||||
|
||||
|
||||
async def get_model_cards() -> list["ModelCard"]:
|
||||
if len(_card_cache) == 0:
|
||||
await _populate_card_cache()
|
||||
return list(_card_cache.values())
|
||||
_card_cache: dict[str, "ModelCard"] = {}
|
||||
|
||||
|
||||
class ModelTask(str, Enum):
|
||||
@@ -81,37 +55,28 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
async def save(self, path: Path) -> None:
|
||||
async with await open_file(path, "w") as f:
|
||||
py = self.model_dump(exclude_none=True)
|
||||
py = self.model_dump()
|
||||
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
|
||||
await f.write(data)
|
||||
|
||||
async def save_to_default_path(self):
|
||||
await self.save(Path(RESOURCES_DIR) / (self.model_id.normalize() + ".toml"))
|
||||
|
||||
@staticmethod
|
||||
async def load_from_path(path: Path) -> "ModelCard":
|
||||
async with await open_file(path, "r") as f:
|
||||
py = tomlkit.loads(await f.read())
|
||||
return ModelCard.model_validate(py)
|
||||
|
||||
# Is it okay that model card.load defaults to network access if the card doesn't exist? do we want to be more explicit here?
|
||||
@staticmethod
|
||||
async def load(model_id: ModelId) -> "ModelCard":
|
||||
if len(_card_cache) == 0:
|
||||
await _populate_card_cache()
|
||||
if (mc := _card_cache.get(model_id)) is not None:
|
||||
return mc
|
||||
|
||||
return await ModelCard.fetch_from_hf(model_id)
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == model_id:
|
||||
return card
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
@staticmethod
|
||||
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
|
||||
async def from_hf(model_id: ModelId) -> "ModelCard":
|
||||
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
|
||||
if len(_card_cache) == 0:
|
||||
await _populate_card_cache()
|
||||
if (mc := _card_cache.get(model_id)) is not None:
|
||||
return mc
|
||||
# TODO: failure if files do not exist
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
@@ -124,13 +89,544 @@ class ModelCard(CamelCaseModel):
|
||||
supports_tensor=config_data.supports_tensor,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
await mc.save_to_default_path()
|
||||
_card_cache[model_id] = mc
|
||||
return mc
|
||||
|
||||
|
||||
# TODO: quantizing and dynamically creating model cards
|
||||
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
storage_size=Memory.from_gb(378),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
storage_size=Memory.from_gb(713),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
storage_size=Memory.from_gb(578),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
storage_size=Memory.from_gb(658),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2.5": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2.5"),
|
||||
storage_size=Memory.from_gb(617),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(4423),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(8540),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
storage_size=Memory.from_mb(16100),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.2
|
||||
"llama-3.2-1b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(696),
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(1777),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(3339),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.3
|
||||
"llama-3.3-70b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(73242),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
storage_size=Memory.from_mb(137695),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# qwen3
|
||||
"qwen3-0.6b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
storage_size=Memory.from_mb(327),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
storage_size=Memory.from_mb(666),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
storage_size=Memory.from_mb(16797),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
storage_size=Memory.from_mb(31738),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(44800),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
storage_size=Memory.from_gb(132),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
storage_size=Memory.from_gb(250),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
storage_size=Memory.from_gb(270),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
storage_size=Memory.from_gb(540),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
storage_size=Memory.from_kb(68_996_301),
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
storage_size=Memory.from_gb(114),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
storage_size=Memory.from_gb(214),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
storage_size=Memory.from_bytes(198556925568),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
storage_size=Memory.from_bytes(286737579648),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
storage_size=Memory.from_bytes(396963397248),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.7 flash
|
||||
"glm-4.7-flash-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
|
||||
storage_size=Memory.from_gb(18),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-5bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
|
||||
storage_size=Memory.from_gb(21),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-6bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
|
||||
storage_size=Memory.from_gb(25),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
|
||||
storage_size=Memory.from_gb(32),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
    # minimax-m2
    "minimax-m2.1-8bit": ModelCard(
        model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
        storage_size=Memory.from_bytes(242986745856),
        n_layers=61,
        hidden_size=3072,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    ),
    "minimax-m2.1-3bit": ModelCard(
        model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
        storage_size=Memory.from_bytes(100086644736),
        n_layers=61,
        hidden_size=3072,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    ),
}
|
||||
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-krea-dev": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
|
||||
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("exolabs/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _generate_image_model_quant_variants(
    base_name: str,
    base_card: ModelCard,
) -> dict[str, ModelCard]:
@@ -210,6 +706,15 @@ def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunctio
    return variants


_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
    _image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards

if EXO_ENABLE_IMAGE_MODELS:
    MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
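
# Illustrative sketch of how a placement decision might read one of the cards
# registered above. It assumes the text-generation cards shown earlier end up in
# MODEL_CARDS; Memory's byte accessor is not shown in this diff, so the per-layer
# footprint stays a comment rather than executable arithmetic.
_example_card = MODEL_CARDS["llama-3.3-70b"]
_example_layers = _example_card.n_layers            # 80 layers available to shard
_example_tensor_ok = _example_card.supports_tensor  # True -> tensor parallel allowed
# A rough per-layer footprint for pipeline sharding would be
# storage_size / n_layers, using whichever byte accessor Memory exposes.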


class ConfigData(BaseModel):
    model_config = {"extra": "ignore"}  # Allow unknown fields
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT

from loguru import logger
from pytest import LogCaptureFixture
from pytest import LogCaptureFixture, mark

from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
        os.remove(p)


@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
    reps = 10
|
||||
|
||||
|
||||
450
src/exo/shared/tracing.py
Normal file
@@ -0,0 +1,450 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import cast, final
|
||||
|
||||
from exo.shared.constants import EXO_TRACING_ENABLED
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to track the current trace category for hierarchical nesting
|
||||
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass(frozen=True)
|
||||
class TraceEvent:
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class CategoryStats:
|
||||
total_us: int = 0
|
||||
count: int = 0
|
||||
min_us: int = 0
|
||||
max_us: int = 0
|
||||
|
||||
def add(self, duration_us: int) -> None:
|
||||
if self.count == 0:
|
||||
self.min_us = duration_us
|
||||
self.max_us = duration_us
|
||||
else:
|
||||
self.min_us = min(self.min_us, duration_us)
|
||||
self.max_us = max(self.max_us, duration_us)
|
||||
self.total_us += duration_us
|
||||
self.count += 1
|
||||
|
||||
@property
|
||||
def avg_us(self) -> float:
|
||||
return self.total_us / self.count if self.count > 0 else 0.0
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TraceStats:
|
||||
total_wall_time_us: int = 0
|
||||
by_category: dict[str, CategoryStats] = field(default_factory=dict)
|
||||
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Global trace buffer - each rank accumulates traces here
|
||||
_trace_buffer: list[TraceEvent] = []
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if tracing is enabled via environment variable."""
|
||||
return EXO_TRACING_ENABLED
|
||||
|
||||
|
||||
def _record_span(
|
||||
name: str, start_us: int, duration_us: int, rank: int, category: str
|
||||
) -> None:
|
||||
_trace_buffer.append(
|
||||
TraceEvent(
|
||||
name=name,
|
||||
start_us=start_us,
|
||||
duration_us=duration_us,
|
||||
rank=rank,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace(
|
||||
name: str,
|
||||
rank: int,
|
||||
category: str = "compute",
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager to trace any operation.
|
||||
|
||||
Nested traces automatically inherit the parent category, creating hierarchical
|
||||
categories like "sync/compute" or "async/comms".
|
||||
|
||||
Args:
|
||||
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
|
||||
rank: This rank's ID
|
||||
category: Category for grouping in the trace viewer (e.g. "comms", "compute", "sync")
|
||||
|
||||
Example:
|
||||
with trace(f"sync {t}", rank, "sync"):
|
||||
with trace("joint_blocks", rank, "compute"):
|
||||
# Recorded with category "sync/compute"
|
||||
hidden_states = some_computation(...)
|
||||
"""
|
||||
if not is_tracing_enabled():
|
||||
yield
|
||||
return
|
||||
|
||||
# Combine with parent category if nested
|
||||
parent = _current_category.get()
|
||||
full_category = f"{parent}/{category}" if parent else category
|
||||
|
||||
# Set as current for nested traces
|
||||
token = _current_category.set(full_category)
|
||||
|
||||
try:
|
||||
start_us = int(time.time() * 1_000_000)
|
||||
start_perf = time.perf_counter()
|
||||
yield
|
||||
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
|
||||
_record_span(name, start_us, duration_us, rank, full_category)
|
||||
finally:
|
||||
_current_category.reset(token)
|
||||
|
||||
|
||||
def get_trace_buffer() -> list[TraceEvent]:
|
||||
return list(_trace_buffer)
|
||||
|
||||
|
||||
def clear_trace_buffer() -> None:
|
||||
_trace_buffer.clear()
|
||||
|
||||
|
||||
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
|
||||
trace_events: list[dict[str, object]] = []
|
||||
|
||||
for event in traces:
|
||||
# Chrome trace format uses "X" for complete events (with duration)
|
||||
chrome_event: dict[str, object] = {
|
||||
"name": event.name,
|
||||
"cat": event.category,
|
||||
"ph": "X",
|
||||
"ts": event.start_us,
|
||||
"dur": event.duration_us,
|
||||
"pid": 0,
|
||||
"tid": event.rank,
|
||||
"args": {"rank": event.rank},
|
||||
}
|
||||
trace_events.append(chrome_event)
|
||||
|
||||
ranks_seen = set(t.rank for t in traces)
|
||||
for rank in ranks_seen:
|
||||
trace_events.append(
|
||||
{
|
||||
"name": "thread_name",
|
||||
"ph": "M", # Metadata event
|
||||
"pid": 0,
|
||||
"tid": rank,
|
||||
"args": {"name": f"Rank {rank}"},
|
||||
}
|
||||
)
|
||||
|
||||
chrome_trace = {"traceEvents": trace_events}
|
||||
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(chrome_trace, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to export trace to %s: %s", output_path, e)
|
||||
|
||||
|
||||
def export_local_traces(rank: int) -> None:
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
local_traces = get_trace_buffer()
|
||||
if local_traces:
|
||||
output_path = Path.home() / ".exo" / "traces" / f"trace_{rank}.json"
|
||||
try:
|
||||
export_trace(local_traces, output_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to export local traces for rank %d: %s", rank, e)
|
||||
|
||||
clear_trace_buffer()
|
||||
|
||||
|
||||
def merge_trace_files(trace_dir: Path | None = None) -> Path | None:
|
||||
if trace_dir is None:
|
||||
trace_dir = Path.home() / ".exo" / "traces"
|
||||
|
||||
if not trace_dir.exists():
|
||||
return None
|
||||
|
||||
trace_files = sorted(trace_dir.glob("trace_*.json"))
|
||||
|
||||
if not trace_files:
|
||||
return None
|
||||
|
||||
merged_events: list[dict[str, object]] = []
|
||||
for trace_file in trace_files:
|
||||
file_rank = int(trace_file.stem.split("_")[1])
|
||||
|
||||
with open(trace_file) as f:
|
||||
raw = f.read()
|
||||
data = cast(dict[str, list[dict[str, object]]], json.loads(raw))
|
||||
events: list[dict[str, object]] = data.get("traceEvents", [])
|
||||
for event in events:
|
||||
event["tid"] = file_rank
|
||||
if "args" in event and isinstance(event["args"], dict):
|
||||
event["args"]["rank"] = file_rank
|
||||
merged_events.extend(events)
|
||||
|
||||
output_path = Path.home() / ".exo" / "traces" / "merged_trace.json"
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump({"traceEvents": merged_events}, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to write merged trace to %s: %s", output_path, e)
|
||||
return None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def _format_duration(us: int | float) -> str:
|
||||
if us < 1000:
|
||||
return f"{us:.0f}µs"
|
||||
elif us < 1_000_000:
|
||||
return f"{us / 1000:.2f}ms"
|
||||
else:
|
||||
return f"{us / 1_000_000:.2f}s"
|
||||
|
||||
|
||||
def load_trace_file(path: Path) -> list[TraceEvent]:
|
||||
"""Load a Chrome Trace Format JSON file into TraceEvent objects."""
|
||||
with open(path) as f:
|
||||
data = cast(dict[str, list[dict[str, object]]], json.load(f))
|
||||
|
||||
events = data.get("traceEvents", [])
|
||||
traces: list[TraceEvent] = []
|
||||
|
||||
for event in events:
|
||||
# Skip metadata events
|
||||
if event.get("ph") == "M":
|
||||
continue
|
||||
|
||||
name = str(event.get("name", ""))
|
||||
category = str(event.get("cat", ""))
|
||||
ts_value = event.get("ts", 0)
|
||||
dur_value = event.get("dur", 0)
|
||||
tid_value = event.get("tid", 0)
|
||||
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
|
||||
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
|
||||
|
||||
# Get rank from tid or args
|
||||
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
|
||||
args = event.get("args")
|
||||
if isinstance(args, dict):
|
||||
args_dict = cast(dict[str, object], args)
|
||||
rank_from_args = args_dict.get("rank")
|
||||
if isinstance(rank_from_args, (int, float, str)):
|
||||
rank = int(rank_from_args)
|
||||
|
||||
traces.append(
|
||||
TraceEvent(
|
||||
name=name,
|
||||
start_us=start_us,
|
||||
duration_us=duration_us,
|
||||
rank=rank,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
return traces
|
||||
|
||||
|
||||
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
|
||||
"""Compute comprehensive statistics from trace events."""
|
||||
stats = TraceStats()
|
||||
|
||||
if not traces:
|
||||
return stats
|
||||
|
||||
# Calculate wall time from earliest start to latest end
|
||||
min_start = min(t.start_us for t in traces)
|
||||
max_end = max(t.start_us + t.duration_us for t in traces)
|
||||
stats.total_wall_time_us = max_end - min_start
|
||||
|
||||
# Initialize nested dicts
|
||||
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
|
||||
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
|
||||
lambda: defaultdict(CategoryStats)
|
||||
)
|
||||
|
||||
for event in traces:
|
||||
# By category
|
||||
by_category[event.category].add(event.duration_us)
|
||||
|
||||
# By rank and category
|
||||
by_rank[event.rank][event.category].add(event.duration_us)
|
||||
|
||||
stats.by_category = dict(by_category)
|
||||
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def print_stats(stats: TraceStats) -> None:
|
||||
"""Print formatted trace statistics."""
|
||||
print("=== Trace Statistics ===")
|
||||
print()
|
||||
print(f"Wall Time: {_format_duration(stats.total_wall_time_us)}")
|
||||
print()
|
||||
|
||||
# Parse hierarchical categories (e.g., "sync/compute" -> phase="sync", subcat="compute")
|
||||
if stats.by_category:
|
||||
phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
|
||||
has_hierarchical = False
|
||||
|
||||
for cat, cat_stats in stats.by_category.items():
|
||||
if "/" in cat:
|
||||
phase, subcat = cat.split("/", 1)
|
||||
phases[phase][subcat] = cat_stats
|
||||
has_hierarchical = True
|
||||
else:
|
||||
phases[cat]["_total"] = cat_stats
|
||||
|
||||
if has_hierarchical:
|
||||
print("By Phase:")
|
||||
for phase in sorted(phases.keys()):
|
||||
subcats = phases[phase]
|
||||
# Skip phases that only have _total (non-hierarchical top-level categories)
|
||||
non_total_subcats = {k: v for k, v in subcats.items() if k != "_total"}
|
||||
if not non_total_subcats:
|
||||
continue
|
||||
|
||||
phase_total = sum(s.total_us for s in non_total_subcats.values())
|
||||
print(f" {phase}:")
|
||||
for subcat, subcat_stats in sorted(
|
||||
non_total_subcats.items(),
|
||||
key=lambda x: x[1].total_us,
|
||||
reverse=True,
|
||||
):
|
||||
pct = (
|
||||
subcat_stats.total_us / phase_total * 100 if phase_total else 0
|
||||
)
|
||||
# Use parent phase's step count for per-step average
|
||||
phase_step_count = subcats.get("_total", CategoryStats()).count
|
||||
if phase_step_count > 0:
|
||||
avg_per_step = subcat_stats.total_us / phase_step_count
|
||||
else:
|
||||
avg_per_step = subcat_stats.avg_us # fallback
|
||||
print(
|
||||
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
|
||||
)
|
||||
print()
|
||||
else:
|
||||
# Fall back to flat category display if no hierarchical categories
|
||||
print("By Category:")
|
||||
total_time = sum(c.total_us for c in stats.by_category.values())
|
||||
for category, cat_stats in sorted(
|
||||
stats.by_category.items(), key=lambda x: x[1].total_us, reverse=True
|
||||
):
|
||||
pct = (cat_stats.total_us / total_time * 100) if total_time > 0 else 0
|
||||
print(
|
||||
f" {category:12s} {_format_duration(cat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(cat_stats.avg_us):>8s} "
|
||||
f"count: {cat_stats.count}"
|
||||
)
|
||||
print()
|
||||
|
||||
# By Rank
|
||||
if stats.by_rank:
|
||||
print("By Rank:")
|
||||
for rank in sorted(stats.by_rank.keys()):
|
||||
rank_stats = stats.by_rank[rank]
|
||||
print(f" Rank {rank}:")
|
||||
|
||||
# Parse hierarchical categories for this rank
|
||||
rank_phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
|
||||
has_hierarchical = False
|
||||
for cat, cat_stats in rank_stats.items():
|
||||
if "/" in cat:
|
||||
phase, subcat = cat.split("/", 1)
|
||||
rank_phases[phase][subcat] = cat_stats
|
||||
has_hierarchical = True
|
||||
else:
|
||||
rank_phases[cat]["_total"] = cat_stats
|
||||
|
||||
if has_hierarchical:
|
||||
for phase in sorted(rank_phases.keys()):
|
||||
subcats = rank_phases[phase]
|
||||
non_total_subcats = {
|
||||
k: v for k, v in subcats.items() if k != "_total"
|
||||
}
|
||||
if not non_total_subcats:
|
||||
continue
|
||||
|
||||
phase_total = sum(s.total_us for s in non_total_subcats.values())
|
||||
print(f" {phase}:")
|
||||
for subcat, subcat_stats in sorted(
|
||||
non_total_subcats.items(),
|
||||
key=lambda x: x[1].total_us,
|
||||
reverse=True,
|
||||
):
|
||||
pct = (
|
||||
subcat_stats.total_us / phase_total * 100
|
||||
if phase_total
|
||||
else 0
|
||||
)
|
||||
# Use parent phase's step count for per-step average
|
||||
phase_step_count = subcats.get("_total", CategoryStats()).count
|
||||
if phase_step_count > 0:
|
||||
avg_per_step = subcat_stats.total_us / phase_step_count
|
||||
else:
|
||||
avg_per_step = subcat_stats.avg_us # fallback
|
||||
print(
|
||||
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
|
||||
)
|
||||
else:
|
||||
# Flat display fallback
|
||||
for category, cat_stats in sorted(
|
||||
rank_stats.items(), key=lambda x: x[1].total_us, reverse=True
|
||||
):
|
||||
print(f" {category}: {_format_duration(cat_stats.total_us)}")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("trace.json")
|
||||
|
||||
if not path.exists():
|
||||
print(f"Error: File not found: {path}")
|
||||
sys.exit(1)
|
||||
|
||||
traces = load_trace_file(path)
|
||||
if not traces:
|
||||
print("No trace events found in file.")
|
||||
sys.exit(0)
|
||||
|
||||
computed_stats = compute_stats(traces)
|
||||
print_stats(computed_stats)
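
A minimal usage sketch for the module above (function names are taken from this file; reading EXO_TRACING_ENABLED from the environment before exo is imported is an assumption about how that flag is populated):

    from exo.shared.tracing import export_local_traces, merge_trace_files, trace

    rank = 0
    with trace("prepare latents", rank, category="compute"):
        ...  # any work to be timed

    export_local_traces(rank)  # writes ~/.exo/traces/trace_0.json and clears the buffer
    merge_trace_files()        # writes ~/.exo/traces/merged_trace.json, or returns None

The exported files are Chrome Trace Format JSON, so they open directly in chrome://tracing or ui.perfetto.dev, and print_stats(compute_stats(load_trace_file(path))) gives the same summary as running this module as a script.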
|
||||
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel

FinishReason = Literal[
    "stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -116,8 +116,8 @@ class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: PromptTokensDetails | None = None
    completion_tokens_details: CompletionTokensDetails | None = None
    prompt_tokens_details: PromptTokensDetails
    completion_tokens_details: CompletionTokensDetails


class StreamingChoiceResponse(BaseModel):
@@ -170,7 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
    generation_stats: GenerationStats | None = None


class ChatCompletionTaskParams(BaseModel):
class StreamOptions(BaseModel):
    include_usage: bool = False


class ChatCompletionTaskParams(TaggedModel):
    model_config = ConfigDict(extra="ignore")

    model: str
    frequency_penalty: float | None = None
    messages: list[ChatCompletionMessage]
@@ -184,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
    seed: int | None = None
    stop: str | list[str] | None = None
    stream: bool = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[dict[str, Any]] | None = None
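
As a hedged illustration of the new streaming knob, a request body that asks for a final usage chunk could look like the following (the OpenAI-style message shape is an assumption; it is not part of this hunk):

    payload = {
        "model": "llama-3.2-3b",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "temperature": 0.7,
    }

Because ChatCompletionTaskParams now carries ConfigDict(extra="ignore"), unknown client fields in such a payload are dropped during validation rather than rejected.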
|
||||
@@ -366,3 +373,45 @@ class StartDownloadResponse(CamelCaseModel):

class DeleteDownloadResponse(CamelCaseModel):
    command_id: CommandId


class TraceEventResponse(CamelCaseModel):
    name: str
    start_us: int
    duration_us: int
    rank: int
    category: str


class TraceResponse(CamelCaseModel):
    task_id: str
    traces: list[TraceEventResponse]


class TraceCategoryStats(CamelCaseModel):
    total_us: int
    count: int
    min_us: int
    max_us: int
    avg_us: float


class TraceRankStats(CamelCaseModel):
    by_category: dict[str, TraceCategoryStats]


class TraceStatsResponse(CamelCaseModel):
    task_id: str
    total_wall_time_us: int
    by_category: dict[str, TraceCategoryStats]
    by_rank: dict[int, TraceRankStats]


class TraceListItem(CamelCaseModel):
    task_id: str
    created_at: str
    file_size: int


class TraceListResponse(CamelCaseModel):
    traces: list[TraceListItem]
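
These response models mirror TraceStats and CategoryStats from exo.shared.tracing field for field; a sketch of the conversion an endpoint might perform (constructing the CamelCaseModel subclasses with snake_case keyword arguments is assumed to work):

    from exo.shared.tracing import CategoryStats, TraceStats

    def _to_category(c: CategoryStats) -> TraceCategoryStats:
        return TraceCategoryStats(
            total_us=c.total_us, count=c.count, min_us=c.min_us, max_us=c.max_us, avg_us=c.avg_us
        )

    def to_stats_response(task_id: str, stats: TraceStats) -> TraceStatsResponse:
        return TraceStatsResponse(
            task_id=task_id,
            total_wall_time_us=stats.total_wall_time_us,
            by_category={k: _to_category(v) for k, v in stats.by_category.items()},
            by_rank={
                rank: TraceRankStats(by_category={k: _to_category(v) for k, v in cats.items()})
                for rank, cats in stats.by_rank.items()
            },
        )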
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
|
||||
|
||||
|
||||
class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -10,7 +11,7 @@ from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
|
||||
|
||||
class EventId(Id):
|
||||
@@ -109,6 +110,22 @@ class TopologyEdgeDeleted(BaseEvent):
|
||||
conn: Connection
|
||||
|
||||
|
||||
@final
|
||||
class TraceEventData(FrozenModel):
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
@final
|
||||
class TracesCollected(BaseEvent):
|
||||
task_id: TaskId
|
||||
rank: int
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -127,6 +144,7 @@ Event = (
|
||||
| InputChunkReceived
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -54,7 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
|
||||
|
||||
class ChatCompletion(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ChatCompletionTaskParams
|
||||
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
@@ -6,6 +6,7 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -24,6 +25,7 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
@@ -57,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -1,45 +1,31 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
|
||||
def find_resources() -> Path:
|
||||
resources = _find_resources_in_repo() or _find_resources_in_bundle()
|
||||
if resources is None:
|
||||
raise FileNotFoundError(
|
||||
"Unable to locate resources. Did you clone the repo properly?"
|
||||
)
|
||||
return resources
|
||||
|
||||
|
||||
def _find_resources_in_repo() -> Path | None:
|
||||
current_module = Path(__file__).resolve()
|
||||
for parent in current_module.parents:
|
||||
build = parent / "resources"
|
||||
if build.is_dir():
|
||||
return build
|
||||
return None
|
||||
|
||||
|
||||
def _find_resources_in_bundle() -> Path | None:
|
||||
frozen_root = cast(str | None, getattr(sys, "_MEIPASS", None))
|
||||
if frozen_root is None:
|
||||
return None
|
||||
candidate = Path(frozen_root) / "resources"
|
||||
if candidate.is_dir():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def find_dashboard() -> Path:
|
||||
dashboard = _find_dashboard_in_repo() or _find_dashboard_in_bundle()
|
||||
dashboard = (
|
||||
_find_dashboard_in_env()
|
||||
or _find_dashboard_in_repo()
|
||||
or _find_dashboard_in_bundle()
|
||||
)
|
||||
if not dashboard:
|
||||
raise FileNotFoundError(
|
||||
"Unable to locate dashboard assets - you probably forgot to run `cd dashboard && npm install && npm run build && cd ..`"
|
||||
"Unable to locate dashboard assets - make sure the dashboard has been built, or export DASHBOARD_DIR if you've built the dashboard elsewhere."
|
||||
)
|
||||
return dashboard
|
||||
|
||||
|
||||
def _find_dashboard_in_env() -> Path | None:
|
||||
env = os.environ.get("DASHBOARD_DIR")
|
||||
if not env:
|
||||
return None
|
||||
resolved_env = Path(env).expanduser().resolve()
|
||||
|
||||
return resolved_env
|
||||
|
||||
|
||||
def _find_dashboard_in_repo() -> Path | None:
|
||||
current_module = Path(__file__).resolve()
|
||||
for parent in current_module.parents:
|
||||
|
||||
@@ -98,8 +98,8 @@ def generate_image(
|
||||
|
||||
partial_images = (
|
||||
task.partial_images
|
||||
if task.partial_images is not None
|
||||
else (3 if task.stream else 0)
|
||||
if task.partial_images is not None and task.stream is not None and task.stream
|
||||
else 0
|
||||
)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
@@ -6,6 +6,11 @@ from mflux.models.common.config.config import Config
|
||||
from mflux.utils.exceptions import StopImageGenerationException
|
||||
from tqdm import tqdm
|
||||
|
||||
from exo.shared.tracing import (
|
||||
clear_trace_buffer,
|
||||
is_tracing_enabled,
|
||||
trace,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
@@ -324,6 +329,7 @@ class DiffusionRunner:
|
||||
capture_steps = set()
|
||||
|
||||
self._reset_all_caches()
|
||||
clear_trace_buffer()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
|
||||
@@ -348,6 +354,7 @@ class DiffusionRunner:
|
||||
ctx.in_loop( # pyright: ignore[reportAny]
|
||||
t=t,
|
||||
latents=latents,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -464,20 +471,22 @@ class DiffusionRunner:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
else:
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
with trace(name=f"async {t}", rank=self.rank, category="async"):
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
|
||||
def _single_node_step(
|
||||
self,
|
||||
@@ -585,30 +594,41 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
with trace(
|
||||
name="joint_blocks",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(encoder_hidden_states, hidden_states)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -619,45 +639,63 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
with trace(
|
||||
name="single blocks",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(hidden_states)
|
||||
|
||||
if not self.is_last_stage:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
@@ -741,14 +779,20 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
with trace(name="send 0", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, 0, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
with trace(
|
||||
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -808,10 +852,13 @@ class DiffusionRunner:
|
||||
and not self.is_last_stage
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
@@ -842,10 +889,13 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
@@ -884,22 +934,28 @@ class DiffusionRunner:
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -908,14 +964,22 @@ class DiffusionRunner:
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
with trace(
|
||||
name=f"joint patch {patch_idx}",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(encoder_hidden_states, patch)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -924,49 +988,70 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
with trace(
|
||||
name=f"single patch {patch_idx}",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(patch)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
|
||||
patch = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch, text_embeddings)
|
||||
|
||||
return noise, encoder_hidden_states
|
||||
|
||||
@@ -201,6 +201,9 @@ def pipeline_auto_parallel(
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
for layer in layers:
|
||||
mx.eval(layer) # type: ignore
|
||||
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
|
||||
@@ -3,6 +3,7 @@ from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
@@ -12,25 +13,29 @@ from mlx_lm.models.cache import (
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.85
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)

class KVPrefixCache:
    def __init__(self, tokenizer: TokenizerWrapper):
    def __init__(
        self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
    ):
        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
        self.caches: list[KVCacheType] = []
        self._last_used: list[int] = []  # monotonic counter of last access per entry
        self._access_counter: int = 0
        self._tokenizer: TokenizerWrapper = tokenizer
        self._group = group

    def clear(self):
        """Clear all cached prompts and caches."""
@@ -81,13 +86,13 @@ class KVPrefixCache:
        best_snapshot_index, best_snapshot_length = None, 0

        for i, cached_prompt in enumerate(self.prompts):
            length = _get_prefix_length(tokenized_prompt, cached_prompt)
            length = get_prefix_length(tokenized_prompt, cached_prompt)

            if length == max_length:
                # Exact match - cached prompt starts with our entire prompt
                # Trim cache to prompt length - 1, return last token for stream_generate
                prompt_cache = deepcopy(self.caches[i])
                cached_length = _cache_length(self.caches[i])
                cached_length = cache_length(self.caches[i])
                tokens_to_trim = cached_length - (max_length - 1)
                if tokens_to_trim > 0:
                    trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -109,7 +114,7 @@ class KVPrefixCache:
        prompt_cache = deepcopy(self.caches[best_snapshot_index])

        # Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
        cached_length = _cache_length(self.caches[best_snapshot_index])
        cached_length = cache_length(self.caches[best_snapshot_index])
        tokens_to_trim = cached_length - best_snapshot_length
        if tokens_to_trim > 0:
            trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -131,29 +136,37 @@ class KVPrefixCache:
        return prompt_cache, tokenized_prompt, None

    def _evict_if_needed(self):
        """Evict least recently used entries while memory pressure is high."""
        """Evict least recently used entries while memory usage is high."""
        if len(self.caches) == 0:
            return

        active: int = mx.metal.get_active_memory()
        limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
        if active < limit * _MEMORY_THRESHOLD:
            return

        # Evict LRU entries until below threshold or only one entry left
        while len(self.caches) > 0:
        while (
            len(self.caches) > 1
            and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
        ):
            lru_index = self._last_used.index(min(self._last_used))
            evicted_tokens = len(self.prompts[lru_index])
            self.prompts.pop(lru_index)
            self.caches.pop(lru_index)
            self._last_used.pop(lru_index)
            logger.info(
                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
            )

            active = mx.metal.get_active_memory()
            if active < limit * _MEMORY_THRESHOLD:
                break

    def get_memory_used_percentage(self) -> float:
        local_pressure: float = get_memory_used_percentage()

        if self._group is None:
            return local_pressure

        all_pressure = mx.distributed.all_gather(
            mx.array([local_pressure], dtype=mx.float32),
            group=self._group,
        )
        # .item() evals.
        max_pressure = float(mx.max(all_pressure).item())
        return max_pressure
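The method above turns per-rank memory usage into a single cluster-wide figure by gathering every rank's local value and taking the maximum. A stand-alone sketch of the same all-gather-then-max pattern, with an illustrative helper name and assuming a group obtained from mx.distributed:

import mlx.core as mx
import psutil

def cluster_memory_pressure(group: mx.distributed.Group | None) -> float:
    """Highest memory-used fraction across all ranks (illustrative sketch)."""
    local = psutil.virtual_memory().percent / 100  # 0.0 - 1.0 on this node
    if group is None:
        return local
    gathered = mx.distributed.all_gather(mx.array([local], dtype=mx.float32), group=group)
    return float(mx.max(gathered).item())  # .item() forces evaluation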

def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -168,13 +181,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
    return mx.array(tokenized_prompt)


def _cache_length(cache: KVCacheType) -> int:
def cache_length(cache: KVCacheType) -> int:
    """Get the number of tokens in a KV cache."""
    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
    return max(c.offset for c in cache)  # type: ignore


def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    """Find the length of the common prefix between two token arrays."""
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
    if n == 0:
@@ -185,6 +198,17 @@ def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    return int(mx.sum(prefix_mask).item())
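Together, get_prefix_length and cache_length decide how much of a stored cache is trimmed before reuse. A small worked example with made-up token arrays (illustrative only):

import mlx.core as mx

prompt = mx.array([1, 2, 3, 4, 5, 6])
cached_prompt = mx.array([1, 2, 3, 4, 9, 9, 9])

hit = get_prefix_length(prompt, cached_prompt)  # 4 shared leading tokens
# If the stored cache holds 7 tokens, trimming (7 - hit) = 3 tokens from the end
# leaves exactly the reusable prefix for the new prompt.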

def get_available_memory() -> Memory:
    mem: int = psutil.virtual_memory().available
    return Memory.from_bytes(mem)


def get_memory_used_percentage() -> float:
    mem = psutil.virtual_memory()
    # percent is 0-100
    return float(mem.percent / 100)


def make_kv_cache(
    model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:
@@ -10,8 +10,11 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.api import (
    BenchChatCompletionTaskParams,
    ChatCompletionMessage,
    CompletionTokensDetails,
    FinishReason,
    GenerationStats,
    PromptTokensDetails,
    Usage,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
@@ -39,7 +42,7 @@ def prefill(
    sampler: Callable[[mx.array], mx.array],
    prompt_tokens: mx.array,
    cache: KVCacheType,
) -> float:
) -> tuple[float, int]:
    """Prefill the KV cache with prompt tokens.

    This runs the model over the prompt tokens to populate the cache,
@@ -50,7 +53,7 @@ def prefill(
    """
    num_tokens = len(prompt_tokens)
    if num_tokens == 0:
        return 0.0
        return 0.0, 0

    logger.debug(f"Prefilling {num_tokens} tokens...")
    start_time = time.perf_counter()
@@ -85,7 +88,7 @@ def prefill(
        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
        f"({tokens_per_sec:.1f} tok/s)"
    )
    return tokens_per_sec
    return tokens_per_sec, num_tokens
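With the return type widened to a tuple, callers unpack both the throughput and the number of prefilled tokens. A minimal sketch of the calling pattern (variable names follow the diff below, the logging line is illustrative):

# prompt_tokens[:-1] is prefilled; stream_generate later starts from the final token.
prefill_tps, prefill_tokens = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
if prefill_tokens:
    # prefill_tokens later feeds prompt accounting, e.g. GenerationStats.prompt_tokens.
    logger.debug(f"prefilled {prefill_tokens} tokens at {prefill_tps:.1f} tok/s")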

def warmup_inference(
@@ -169,6 +172,8 @@ def mlx_generate(
    mx.reset_peak_memory()
    is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)

    logger.info(f"{is_bench=}")

    # Currently we support chat-completion tasks only.
    logger.debug(f"task_params: {task}")

@@ -204,7 +209,9 @@ def mlx_generate(
    )

    # Prefill cache with all tokens except the last one
    prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
    prefill_tps, prefill_tokens = prefill(
        model, tokenizer, sampler, prompt_tokens[:-1], caches
    )

    # stream_generate starts from the last token
    last_token = prompt_tokens[-1:]
@@ -212,28 +219,43 @@ def mlx_generate(
    max_tokens = task.max_tokens or MAX_TOKENS
    generated_text_parts: list[str] = []
    generation_start_time = time.perf_counter()
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=last_token,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prompt_cache=caches,
        # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    usage: Usage | None = None
    in_thinking = False
    reasoning_tokens = 0
    think_start = tokenizer.think_start
    think_end = tokenizer.think_end
    for completion_tokens, out in enumerate(
        stream_generate(
            model=model,
            tokenizer=tokenizer,
            prompt=last_token,
            max_tokens=max_tokens,
            sampler=sampler,
            logits_processors=logits_processors,
            prompt_cache=caches,
            # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
            prefill_step_size=2048,
            kv_group_size=KV_GROUP_SIZE,
            kv_bits=KV_BITS,
        ),
        start=1,
    ):
        generated_text_parts.append(out.text)
        logger.info(out.text)

        if think_start is not None and out.text == think_start:
            in_thinking = True
        elif think_end is not None and out.text == think_end:
            in_thinking = False
        if in_thinking:
            reasoning_tokens += 1

        stats: GenerationStats | None = None
        if out.finish_reason is not None:
            stats = GenerationStats(
                prompt_tps=float(prefill_tps or out.prompt_tps),
                generation_tps=float(out.generation_tps),
                prompt_tokens=int(out.prompt_tokens),
                prompt_tokens=int(prefill_tokens + out.prompt_tokens),
                generation_tokens=int(out.generation_tokens),
                peak_memory_usage=Memory.from_gb(out.peak_memory),
            )
@@ -245,11 +267,24 @@ def mlx_generate(
                    f"Model generated unexpected finish_reason: {out.finish_reason}"
                )

            usage = Usage(
                prompt_tokens=int(out.prompt_tokens),
                completion_tokens=completion_tokens,
                total_tokens=int(out.prompt_tokens) + completion_tokens,
                prompt_tokens_details=PromptTokensDetails(
                    cached_tokens=prefix_hit_length
                ),
                completion_tokens_details=CompletionTokensDetails(
                    reasoning_tokens=reasoning_tokens
                ),
            )

        yield GenerationResponse(
            text=out.text,
            token=out.token,
            finish_reason=cast(FinishReason | None, out.finish_reason),
            stats=stats,
            usage=usage,
        )

        if out.finish_reason is not None:
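The usage block above follows OpenAI-style accounting: completion_tokens comes from the enumerate counter, cached_tokens from the KV-prefix hit, and reasoning_tokens from tokens emitted between the tokenizer's think markers. A worked example with illustrative numbers (not taken from the diff):

# 900 prompt tokens reported by stream_generate, 40 generated tokens,
# a 512-token prefix hit, and 12 tokens produced between think_start/think_end.
usage = Usage(
    prompt_tokens=900,
    completion_tokens=40,
    total_tokens=940,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=512),
    completion_tokens_details=CompletionTokensDetails(reasoning_tokens=12),
)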
@@ -18,6 +18,7 @@ from pydantic import ValidationError

from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer, is_tracing_enabled
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -27,6 +28,8 @@ from exo.shared.types.events import (
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
    TraceEventData,
    TracesCollected,
)
from exo.shared.types.tasks import (
    ChatCompletion,
@@ -37,6 +40,7 @@ from exo.shared.types.tasks import (
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -111,8 +115,12 @@ def main(
    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )
    seen = set[TaskId]()
    with task_receiver as tasks:
        for task in tasks:
            if task.task_id in seen:
                logger.warning("repeat task - potential error")
            seen.add(task.task_id)
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
@@ -163,7 +171,7 @@ def main(
                    logger.info(
                        f"model has_tool_calling={tokenizer.has_tool_calling}"
                    )
                    kv_prefix_cache = KVPrefixCache(tokenizer)
                    kv_prefix_cache = KVPrefixCache(tokenizer, group)

                elif (
                    ModelTask.TextToImage in shard_metadata.model_card.tasks
@@ -277,9 +285,11 @@ def main(
                        tokenizer.tool_parser,  # pyright: ignore[reportAny]
                    )

                    completion_tokens = 0
                    for response in mlx_generator:
                        match response:
                            case GenerationResponse():
                                completion_tokens += 1
                                if (
                                    device_rank == 0
                                    and response.finish_reason == "error"
@@ -307,6 +317,7 @@ def main(
                                            model=shard_metadata.model_card.model_id,
                                            text=response.text,
                                            token_id=response.token,
                                            usage=response.usage,
                                            finish_reason=response.finish_reason,
                                            stats=response.stats,
                                        ),
@@ -320,6 +331,7 @@ def main(
                                        chunk=ToolCallChunk(
                                            tool_calls=response.tool_calls,
                                            model=shard_metadata.model_card.model_id,
                                            usage=response.usage,
                                        ),
                                    )
                                )
@@ -399,6 +411,10 @@ def main(
                        )
                    )
                    raise
                finally:
                    _send_traces_if_enabled(
                        event_sender, task.task_id, shard_metadata.device_rank
                    )

                current_status = RunnerReady()
                logger.info("runner ready")
@@ -457,6 +473,10 @@ def main(
                        )
                    )
                    raise
                finally:
                    _send_traces_if_enabled(
                        event_sender, task.task_id, shard_metadata.device_rank
                    )

                current_status = RunnerReady()
                logger.info("runner ready")
@@ -535,10 +555,10 @@ def parse_gpt_oss(
                        name=current_tool_name,
                        arguments="".join(tool_arg_parts).strip(),
                    )
                ]
                ],
                usage=response.usage,
            )
            tool_arg_parts = []
            break
        current_tool_name = recipient

    # If inside a tool call, accumulate arguments
@@ -631,6 +651,36 @@ def _send_image_chunk(
    )


def _send_traces_if_enabled(
    event_sender: MpSender[Event],
    task_id: TaskId,
    rank: int,
) -> None:
    if not is_tracing_enabled():
        return

    traces = get_trace_buffer()
    if traces:
        trace_data = [
            TraceEventData(
                name=t.name,
                start_us=t.start_us,
                duration_us=t.duration_us,
                rank=t.rank,
                category=t.category,
            )
            for t in traces
        ]
        event_sender.send(
            TracesCollected(
                task_id=task_id,
                rank=rank,
                traces=trace_data,
            )
        )
    clear_trace_buffer()


def _process_image_response(
    response: ImageGenerationResponse | PartialImageResponse,
    command_id: CommandId,
@@ -684,7 +734,7 @@ def parse_tool_calls(
        tools = [_validate_single_tool(tool) for tool in parsed]
    else:
        tools = [_validate_single_tool(parsed)]
        yield ToolCallResponse(tool_calls=tools)
        yield ToolCallResponse(tool_calls=tools, usage=response.usage)

    except (
        json.JSONDecodeError,
@@ -127,20 +127,25 @@ class RunnerSupervisor:
        self._tg.cancel_scope.cancel()

    async def start_task(self, task: Task):
        if task.task_id in self.pending:
            logger.warning(
                f"Skipping invalid task {task} as it has already been submitted"
            )
            return
        if task.task_id in self.completed:
            logger.info(
            logger.warning(
                f"Skipping invalid task {task} as it has already been completed"
            )
            return
        logger.info(f"Starting task {task}")
        event = anyio.Event()
        self.pending[task.task_id] = event
        try:
            self._task_sender.send(task)
            await self._task_sender.send_async(task)
        except ClosedResourceError:
            logger.warning(f"Task {task} dropped, runner closed communication.")
            return
        await event.wait()
        logger.info(f"Finished task {task}")

    async def _forward_events(self):
        with self._ev_recv as events:

@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
    KVPrefixCache,
    _cache_length,
    _get_prefix_length,
    cache_length,
    encode_prompt,
    get_prefix_length,
    make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
    def test_identical_arrays(self):
        a = mx.array([1, 2, 3, 4, 5])
        b = mx.array([1, 2, 3, 4, 5])
        assert _get_prefix_length(a, b) == 5
        assert get_prefix_length(a, b) == 5

    def test_no_common_prefix(self):
        a = mx.array([1, 2, 3])
        b = mx.array([4, 5, 6])
        assert _get_prefix_length(a, b) == 0
        assert get_prefix_length(a, b) == 0

    def test_partial_prefix(self):
        a = mx.array([1, 2, 3, 4, 5])
        b = mx.array([1, 2, 3, 7, 8])
        assert _get_prefix_length(a, b) == 3
        assert get_prefix_length(a, b) == 3

    def test_prompt_longer_than_cached(self):
        a = mx.array([1, 2, 3, 4, 5])
        b = mx.array([1, 2, 3])
        assert _get_prefix_length(a, b) == 3
        assert get_prefix_length(a, b) == 3

    def test_cached_longer_than_prompt(self):
        a = mx.array([1, 2, 3])
        b = mx.array([1, 2, 3, 4, 5])
        assert _get_prefix_length(a, b) == 3
        assert get_prefix_length(a, b) == 3

    def test_single_token_match(self):
        a = mx.array([1, 2, 3])
        b = mx.array([1, 5, 6])
        assert _get_prefix_length(a, b) == 1
        assert get_prefix_length(a, b) == 1

    def test_empty_prompt(self):
        a = mx.array([]).astype(mx.int32)
        b = mx.array([1, 2, 3])
        assert _get_prefix_length(a, b) == 0
        assert get_prefix_length(a, b) == 0

    def test_empty_cached(self):
        a = mx.array([1, 2, 3])
        b = mx.array([]).astype(mx.int32)
        assert _get_prefix_length(a, b) == 0
        assert get_prefix_length(a, b) == 0

    def test_both_empty(self):
        a = mx.array([]).astype(mx.int32)
        b = mx.array([]).astype(mx.int32)
        assert _get_prefix_length(a, b) == 0
        assert get_prefix_length(a, b) == 0


class TestKVPrefix:
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        # Cache should now hold the prompt tokens
        assert _cache_length(cache) == len(tokens)
        assert cache_length(cache) == len(tokens)

    def test_add_and_get_exact_match(self, model_and_tokenizer):
        model, tokenizer = model_and_tokenizer
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
        kv_prefix_cache.add_kv_cache(prompt, cache)

        assert len(kv_prefix_cache.prompts) == 1
        stored_length = _cache_length(kv_prefix_cache.caches[0])
        stored_length = cache_length(kv_prefix_cache.caches[0])
        assert stored_length > 0

        # Retrieve with same prompt: exact match
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
        long_tokens = encode_prompt(tokenizer, long_prompt)

        # The prompts share a prefix (chat template preamble + "Hi")
        expected_prefix = _get_prefix_length(long_tokens, short_tokens)
        expected_prefix = get_prefix_length(long_tokens, short_tokens)
        assert expected_prefix > 0, (
            "Prompts should share a prefix from the chat template"
        )
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)

        stored_length = _cache_length(kv_prefix_cache.caches[0])
        stored_length = cache_length(kv_prefix_cache.caches[0])

        # Get cache and mutate it (simulating what generation does)
        result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
        mx.eval([c.keys for c in result_cache])

        # Stored cache must be unchanged
        assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
        assert cache_length(kv_prefix_cache.caches[0]) == stored_length

    def test_stored_cache_survives_repeated_get_mutate_cycles(
        self, model_and_tokenizer
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)

        stored_length = _cache_length(kv_prefix_cache.caches[0])
        stored_length = cache_length(kv_prefix_cache.caches[0])

        for i in range(3):
            result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
            layer_cache.update_and_fetch(extra, extra)
            mx.eval([c.keys for c in result_cache])

            assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
            assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
                f"Failed on loop {i}"
            )

@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
        assert len(kv_prefix_cache.caches) == 1
        # Cache should contain prompt + generated tokens
        expected_length = len(prompt_tokens) + generated_tokens
        assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
        assert cache_length(kv_prefix_cache.caches[0]) == expected_length

    def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
        """Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
        first_gen_time = time.perf_counter() - t0

        assert len(kv_prefix_cache.prompts) == 1
        first_cache_length = _cache_length(kv_prefix_cache.caches[0])
        first_cache_length = cache_length(kv_prefix_cache.caches[0])

        # Second generation: same long prompt + extra content (simulating multi-turn)
        task2 = ChatCompletionTaskParams(
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
        prompt2_tokens = encode_prompt(tokenizer, prompt2)

        # Verify the prompts share a long prefix
        prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
        prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
        assert prefix_len > 1000, "Prompts must share > 1000 token prefix"

        # Second generation should reuse the cached prefix (only prefill new tokens)
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
        # With prefix_hit > 1000, should update in-place (not add a second entry)
        assert len(kv_prefix_cache.prompts) == 1
        # Updated cache should be longer (prompt2 + generated > prompt1 + generated)
        updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
        updated_cache_length = cache_length(kv_prefix_cache.caches[0])
        assert updated_cache_length > first_cache_length

    def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
        ):
            pass

        first_cache_length = _cache_length(kv_prefix_cache.caches[0])
        firstcache_length = cache_length(kv_prefix_cache.caches[0])

        # Second generation gets the cache and mutates it during generation
        for _response in mlx_generate(
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
            pass

        # The first stored cache must not have been mutated by the second generation
        assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
        assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length

    def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
        """Under memory pressure, adding a new cache entry evicts the least recently used one."""
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
        assert len(kv_prefix_cache.prompts) == 1
        # The surviving entry should be the newly added one
        new_tokens = encode_prompt(tokenizer, prompt)
        assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
        assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
            new_tokens
        )
@@ -16,7 +16,7 @@ from exo.download.download_utils import (
    ensure_models_dir,
    fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
    get_eos_token_ids_for_model,
    load_tokenizer_for_model_id,
@@ -76,7 +76,7 @@ def get_test_models() -> list[ModelCard]:
    """Get a representative sample of models to test."""
    # Pick one model from each family to test
    families: dict[str, ModelCard] = {}
    for card in asyncio.run(get_model_cards()):
    for card in MODEL_CARDS.values():
        # Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
        parts = card.model_id.short().split("-")
        family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
@@ -296,7 +296,7 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
    """Test Kimi tokenizer with its specific patches and quirks."""
    kimi_models = [
        card for card in await get_model_cards() if "kimi" in card.model_id.lower()
        card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
    ]

    if not kimi_models:
@@ -343,7 +343,7 @@ async def test_kimi_tokenizer_specifically():
async def test_glm_tokenizer_specifically():
    """Test GLM tokenizer with its specific EOS tokens."""
    glm_model_cards = [
        card for card in await get_model_cards() if "glm" in card.model_id.lower()
        card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
    ]

    if not glm_model_cards:
@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even

@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # initialize_mlx returns a "group" equal to 1
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
    # initialize_mlx returns a mock group
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))

    def fake_generate(*_1: object, **_2: object):
        yield GenerationResponse(token=0, text="hi", finish_reason="stop")
        yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)

    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)

@@ -147,6 +147,14 @@ class MockTokenizer:
    has_tool_calling = False


class MockGroup:
    def rank(self) -> int:
        return 0

    def size(self) -> int:
        return 1


def _run(tasks: Iterable[Task]):
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
@@ -182,6 +190,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
                    text="hi",
                    token_id=0,
                    finish_reason="stop",
                    usage=None,
                    stats=None,
                ),
            )
@@ -16,7 +16,7 @@ from exo.download.impl_shard_downloader import (
    exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import ModelId
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
@@ -89,26 +89,22 @@ async def tb_detection():

async def assert_downloads():
    sd = exo_shard_downloader()
    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/Qwen3-0.6B-4bit"))
        await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
    )
    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"))
        await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"))
        await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"))
    )
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"))
    )
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/GLM-4.7-8bit-gs32"))
    )
    await sd.ensure_shard(
        await build_full_shard(ModelId("mlx-community/MiniMax-M2.1-8bit"))
        await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
    )
18
tmp/config_examples/opencode.json
Normal file
@@ -0,0 +1,18 @@
{
  "$schema": "https://opencode.ai/config.json",
  "model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
  "provider": {
    "exo": {
      "api": "http://localhost:52415/v1",
      "models": {
        "mlx-community/gpt-oss-120b-MXFP4-Q8": {
          "name": "GPT OSS 120B",
          "limit": {
            "context": 32768,
            "output": 8192
          }
        }
      }
    }
  }
}
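The config points opencode at exo's OpenAI-compatible endpoint. The same endpoint and model id can be exercised directly; a quick check in Python (request body illustrative, assuming the server from this config is running locally):

import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps({
        "model": "mlx-community/gpt-oss-120b-MXFP4-Q8",
        "messages": [{"role": "user", "content": "Say hello"}],
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])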