download fixes

Alex Cheema
2025-10-22 11:56:52 +01:00
committed by GitHub
parent 363c98a872
commit a346af3477
17 changed files with 1240 additions and 961 deletions

View File

@@ -1,154 +0,0 @@
# name: Build and Release Exo macOS App
# on:
# push:
# tags:
# - 'v*' # Trigger on version tags
# branches:
# - main # Also build on main branch for testing
# - staging
# - python-modules # Build python-modules for testing
# pull_request:
# branches:
# - staging # Test builds on PRs to staging
# - main # Build on PRs to main
# jobs:
# build-exov2-macos:
# runs-on: macos-15
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# - name: Install Go
# uses: actions/setup-go@v5
# with:
# go-version: '1.21'
# - name: Install Just
# run: |
# brew install just
# - name: Install UV
# uses: astral-sh/setup-uv@v6
# with:
# enable-cache: true
# cache-dependency-glob: uv.lock
# - name: Setup Python Environment
# run: |
# uv python install
# uv sync --locked --all-extras
# - name: Verify Python Environment
# run: |
# uv run python -c "import master.main; print('Master module available')"
# uv run python -c "import worker.main; print('Worker module available')"
# - name: Prepare Code Signing Keychain
# env:
# MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }}
# MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
# PROVISIONING_PROFILE: ${{ secrets.PROVISIONING_PROFILE }}
# run: |
# security create-keychain -p "$MACOS_CERTIFICATE_PASSWORD" exov2.keychain
# security default-keychain -s exov2.keychain
# security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" exov2.keychain
# echo "$MACOS_CERTIFICATE" | base64 --decode > /tmp/exov2-certificate.p12
# security import /tmp/exov2-certificate.p12 -k exov2.keychain -P "$MACOS_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
# rm /tmp/exov2-certificate.p12
# security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$MACOS_CERTIFICATE_PASSWORD" exov2.keychain
# PROFILES_HOME="$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles"
# mkdir -p "$PROFILES_HOME"
# PROFILE_PATH="$(mktemp "$PROFILES_HOME"/EXOV2_PP.provisionprofile)"
# echo "$PROVISIONING_PROFILE" | base64 --decode > "$PROFILE_PATH"
# - name: Build Exo Swift App
# env:
# MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
# run: |
# cd app/exov2
# sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
# # Release build with code signing
# security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" exov2.keychain
# SIGNING_IDENTITY=$(security find-identity -v -p codesigning | awk -F '"' '{print $2}')
# xcodebuild clean build \
# -project exov2.xcodeproj \
# -scheme exov2 \
# -configuration Release \
# -derivedDataPath build \
# CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
# PROVISIONING_PROFILE_SPECIFIER="Exo Provisioning Profile" \
# CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES \
# OTHER_CODE_SIGN_FLAGS="--timestamp"
# mv build/Build/Products/*/EXO.app ../../
# - name: Sign, Notarize, and Create DMG
# env:
# APPLE_NOTARIZATION_USERNAME: ${{ secrets.APPLE_NOTARIZATION_USERNAME }}
# APPLE_NOTARIZATION_PASSWORD: ${{ secrets.APPLE_NOTARIZATION_PASSWORD }}
# APPLE_NOTARIZATION_TEAM: ${{ secrets.APPLE_NOTARIZATION_TEAM }}
# MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
# run: |
# security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" exov2.keychain
# SIGNING_IDENTITY=$(security find-identity -v -p codesigning | awk -F '"' '{print $2}')
# # Sign the app
# /usr/bin/codesign --deep --force --timestamp --options runtime \
# --sign "$SIGNING_IDENTITY" EXO.app
# # Verify the signing
# codesign -dvv EXO.app
# # Create DMG
# mkdir -p tmp/dmg-contents
# cp -r ./EXO.app tmp/dmg-contents/
# ln -s /Applications tmp/dmg-contents/Applications
# DMG_NAME="exo.dmg"
# # Create and sign DMG
# hdiutil create -volname "Exo" -srcfolder tmp/dmg-contents -ov -format UDZO "$DMG_NAME"
# /usr/bin/codesign --deep --force --timestamp --options runtime \
# --sign "$SIGNING_IDENTITY" "$DMG_NAME"
# # Setup notarization credentials (optional - comment out if no notarization secrets)
# if [[ -n "$APPLE_NOTARIZATION_USERNAME" ]]; then
# xcrun notarytool store-credentials notary_pass \
# --apple-id "$APPLE_NOTARIZATION_USERNAME" \
# --password "$APPLE_NOTARIZATION_PASSWORD" \
# --team-id "$APPLE_NOTARIZATION_TEAM"
# # Submit for notarization
# xcrun notarytool submit --wait \
# --team-id "$APPLE_NOTARIZATION_TEAM" \
# --keychain-profile notary_pass \
# "$DMG_NAME"
# # Staple the notarization
# xcrun stapler staple "$DMG_NAME"
# fi
# - name: Cleanup Keychain
# if: always()
# run: |
# security default-keychain -s login.keychain
# security delete-keychain exov2.keychain
# - name: Upload DMG file
# uses: actions/upload-artifact@v4
# with:
# name: exo-dmg
# path: exo.dmg
# - name: Upload App Bundle
# uses: actions/upload-artifact@v4
# with:
# name: exov2-app
# path: EXO.app/

View File

Binary file not shown (image added, 99 KiB).

View File

Binary file not shown (image added, 34 KiB).

View File

Binary file not shown (image added, 12 KiB).

dashboard/favicon.ico (new file)
View File

Binary file not shown (image added, 4.2 KiB).

View File

@@ -492,6 +492,162 @@
transition: width 0.3s ease;
}
/* Detailed download info */
.download-details {
margin-top: 8px;
padding: 12px;
background-color: #1a1a1a;
border: 1px solid var(--exo-medium-gray);
border-radius: 6px;
box-sizing: border-box;
width: 100%;
max-width: 100%;
overflow: visible;
}
.download-runner-header {
font-size: 11px;
color: var(--exo-light-gray);
opacity: 0.85;
margin-bottom: 4px;
}
.download-overview-row {
display: flex;
gap: 12px;
flex-wrap: wrap;
font-size: 12px;
margin-bottom: 8px;
}
.download-overview-item strong {
color: #E0E0E0;
font-weight: 600;
margin-right: 4px;
}
.progress-with-label {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 10px;
}
.progress-with-label .progress-bar-container {
flex: 1 1 auto;
}
.progress-percent {
font-size: 12px;
color: var(--exo-light-gray);
font-variant-numeric: tabular-nums;
white-space: nowrap;
}
.download-overview-combined {
font-size: 12px;
color: var(--exo-light-gray);
opacity: 0.9;
}
.instance-download-summary {
font-size: 11px;
color: var(--exo-light-gray);
margin-top: 6px;
opacity: 0.95;
}
.download-files-list {
display: grid;
gap: 8px;
}
.download-file {
padding: 8px;
background-color: var(--exo-dark-gray);
border: 1px solid var(--exo-medium-gray);
border-radius: 6px;
box-sizing: border-box;
width: 100%;
max-width: 100%;
}
.download-file-header {
display: flex;
justify-content: space-between;
align-items: center;
gap: 10px;
font-size: 11px;
margin-bottom: 6px;
width: 100%;
max-width: 100%;
overflow: hidden;
}
.download-file-name {
color: #E0E0E0;
font-weight: 500;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
min-width: 0;
flex: 1 1 auto;
}
.download-file-stats {
color: var(--exo-light-gray);
text-align: right;
white-space: nowrap;
}
.download-file-percent {
color: var(--exo-light-gray);
white-space: nowrap;
font-size: 11px;
font-variant-numeric: tabular-nums;
flex: 0 0 auto;
}
.download-file-subtext {
color: var(--exo-light-gray);
font-size: 10px;
opacity: 0.85;
margin-bottom: 6px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
max-width: 100%;
}
.download-details, .download-files-list {
box-sizing: border-box;
width: 100%;
max-width: 100%;
}
.download-files-list {
overflow: visible;
padding-right: 2px; /* avoid edge clipping */
}
.download-file .progress-bar-container {
width: 100%;
max-width: 100%;
box-sizing: border-box;
height: 5px;
}
.completed-files-section {
margin-top: 12px;
padding-top: 8px;
border-top: 1px solid var(--exo-medium-gray);
}
.completed-files-header {
font-size: 10px;
color: var(--exo-light-gray);
opacity: 0.7;
margin-bottom: 6px;
font-weight: 500;
}
.completed-files-list {
display: flex;
flex-direction: column;
gap: 3px;
}
.completed-file-item {
font-size: 10px;
color: var(--exo-light-gray);
opacity: 0.8;
padding: 3px 6px;
background-color: rgba(74, 222, 128, 0.1);
border-left: 2px solid #4ade80;
border-radius: 3px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
/* Launch instance section styles */
.launch-instance-section {
display: flex;
@@ -750,6 +906,7 @@
const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA
let currentlySelectedNodeId = null; // To store the ID of the currently selected node
let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections
const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state';
const REFRESH_INTERVAL = 1000; // 1 second
@@ -855,6 +1012,36 @@
return days + (days === 1 ? ' day ago' : ' days ago');
}
// --- Download formatting helpers ---
function bytesFromValue(value) {
if (typeof value === 'number') return value;
if (!value || typeof value !== 'object') return 0;
if (typeof value.in_bytes === 'number') return value.in_bytes;
if (typeof value.inBytes === 'number') return value.inBytes;
return 0;
}
function formatDurationMs(ms) {
if (ms == null || isNaN(ms) || ms < 0) return '—';
const totalSeconds = Math.round(ms / 1000);
const s = totalSeconds % 60;
const m = Math.floor(totalSeconds / 60) % 60;
const h = Math.floor(totalSeconds / 3600);
if (h > 0) return `${h}h ${m}m ${s}s`;
if (m > 0) return `${m}m ${s}s`;
return `${s}s`;
}
function formatPercent(value, digits = 2) {
if (value == null || isNaN(value)) return '0.00%';
return `${value.toFixed(digits)}%`;
}
function formatBytesPerSecond(bps) {
if (bps == null || isNaN(bps) || bps < 0) return '0 B/s';
return `${formatBytes(bps)}/s`;
}
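These helpers absorb the two serializations of byte counts the dashboard may receive: a Memory value can arrive either as a raw number or as an object carrying in_bytes/inBytes. A minimal Python sketch of the same dispatch (field names taken from bytesFromValue above; the real Memory type lives in exo.shared.types.memory):

def bytes_from_value(value):
    # Accept a raw number or a Memory-style mapping, mirroring bytesFromValue().
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, dict):
        return value.get("in_bytes") or value.get("inBytes") or 0
    return 0

assert bytes_from_value(1024) == 1024
assert bytes_from_value({"inBytes": 2048}) == 2048
assert bytes_from_value(None) == 0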
// Sidebar toggle functionality
let sidebarOpen = false;
@@ -934,7 +1121,7 @@
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ modelId: selectedModelId, model_id: selectedModelId })
body: JSON.stringify({ model_id: selectedModelId })
});
if (!response.ok) {
@@ -974,66 +1161,185 @@
}
}
// Calculate download status for an instance based on its runners
// Calculate download status for an instance based on its runners, with detailed per-file info
function calculateInstanceDownloadStatus(instance, runners) {
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
const runnerToShard = shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard;
if (!runnerToShard || !runners) {
return { isDownloading: false, progress: 0 };
if (!instance.shardAssignments?.runnerToShard || !runners) {
return { isDownloading: false, progress: 0, details: [] };
}
const runnerIds = Object.keys(runnerToShard);
const downloadingRunners = [];
const pick = (obj, snake, camel, fallback = undefined) => {
if (!obj) return fallback;
if (obj[snake] !== undefined) return obj[snake];
if (obj[camel] !== undefined) return obj[camel];
return fallback;
};
// Returns [tag, payload] for objects serialized as {Tag: {...}}, else [null, null]
function getTagged(obj) {
if (!obj || typeof obj !== 'object') return [null, null];
const keys = Object.keys(obj);
if (keys.length === 1 && typeof keys[0] === 'string') {
return [keys[0], obj[keys[0]]];
}
return [null, null];
}
function normalizeProgress(progressRaw) {
if (!progressRaw) return null;
const totalBytes = bytesFromValue(pick(progressRaw, 'total_bytes', 'totalBytes', 0));
const downloadedBytes = bytesFromValue(pick(progressRaw, 'downloaded_bytes', 'downloadedBytes', 0));
const downloadedBytesThisSession = bytesFromValue(pick(progressRaw, 'downloaded_bytes_this_session', 'downloadedBytesThisSession', 0));
const completedFiles = Number(pick(progressRaw, 'completed_files', 'completedFiles', 0)) || 0;
const totalFiles = Number(pick(progressRaw, 'total_files', 'totalFiles', 0)) || 0;
const speed = Number(pick(progressRaw, 'speed', 'speed', 0)) || 0;
const etaMs = Number(pick(progressRaw, 'eta_ms', 'etaMs', 0)) || 0;
const filesObj = pick(progressRaw, 'files', 'files', {}) || {};
const files = [];
Object.keys(filesObj).forEach(name => {
const f = filesObj[name];
if (!f || typeof f !== 'object') return;
const fTotal = bytesFromValue(pick(f, 'total_bytes', 'totalBytes', 0));
const fDownloaded = bytesFromValue(pick(f, 'downloaded_bytes', 'downloadedBytes', 0));
const fSpeed = Number(pick(f, 'speed', 'speed', 0)) || 0;
const fEta = Number(pick(f, 'eta_ms', 'etaMs', 0)) || 0;
const fPct = fTotal > 0 ? (fDownloaded / fTotal) * 100 : 0;
files.push({ name, totalBytes: fTotal, downloadedBytes: fDownloaded, speed: fSpeed, etaMs: fEta, percentage: fPct });
});
const percentage = totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0;
return { totalBytes, downloadedBytes, downloadedBytesThisSession, completedFiles, totalFiles, speed, etaMs, files, percentage };
}
const runnerIds = Object.keys(instance.shardAssignments.runnerToShard);
const details = [];
let totalBytes = 0;
let downloadedBytes = 0;
for (const runnerId of runnerIds) {
const runner = runners[runnerId];
let isRunnerDownloading = false;
if (!runner) continue;
// Legacy snake_case structure
if (runner && runner.runner_status === 'Downloading' && runner.download_progress) {
isRunnerDownloading = runner.download_progress.download_status === 'Downloading';
if (isRunnerDownloading && runner.download_progress.download_progress) {
totalBytes += runner.download_progress.download_progress.total_bytes || 0;
downloadedBytes += runner.download_progress.download_progress.downloaded_bytes || 0;
}
} else if (runner && typeof runner === 'object') {
// Tagged-union camelCase structure, e.g. { "DownloadingRunnerStatus": { downloadProgress: { totalBytes, downloadedBytes } } }
const tag = Object.keys(runner)[0];
if (tag && /DownloadingRunnerStatus$/i.test(tag)) {
isRunnerDownloading = true;
const inner = runner[tag] || {};
const prog = inner.downloadProgress || inner.download_progress || {};
const t = prog.totalBytes ?? prog.total_bytes ?? 0;
const d = prog.downloadedBytes ?? prog.downloaded_bytes ?? 0;
totalBytes += typeof t === 'number' ? t : 0;
downloadedBytes += typeof d === 'number' ? d : 0;
}
// New tagged format: { "DownloadingRunnerStatus": { downloadProgress: { "DownloadOngoing": { ... } } } }
const [statusKind, statusPayload] = getTagged(runner);
let nodeId;
let rawProg;
if (statusKind === 'DownloadingRunnerStatus') {
const dpTagged = statusPayload && (statusPayload.downloadProgress || statusPayload.download_progress);
const [dpKind, dpPayload] = getTagged(dpTagged);
if (dpKind !== 'DownloadOngoing') continue;
nodeId = (dpPayload && (dpPayload.nodeId || dpPayload.node_id)) || undefined;
rawProg = pick(dpPayload, 'download_progress', 'downloadProgress', null);
} else {
// Backward compatibility with old flat shape
if (runner.runnerStatus !== 'Downloading' || !runner.downloadProgress) continue;
const dp = runner.downloadProgress;
const isDownloading = (dp.downloadStatus === 'Downloading') || (dp.download_status === 'Downloading');
if (!isDownloading) continue;
nodeId = (dp && (dp.nodeId || dp.node_id)) || undefined;
rawProg = pick(dp, 'download_progress', 'downloadProgress', null);
}
if (isRunnerDownloading) downloadingRunners.push(runner);
const normalized = normalizeProgress(rawProg);
if (!normalized) continue;
details.push({ runnerId, nodeId, progress: normalized });
totalBytes += normalized.totalBytes || 0;
downloadedBytes += normalized.downloadedBytes || 0;
}
const isDownloading = downloadingRunners.length > 0;
const progress = totalBytes > 0 ? Math.round((downloadedBytes / totalBytes) * 100) : 0;
const isDownloadingAny = details.length > 0;
const progress = totalBytes > 0 ? ((downloadedBytes / totalBytes) * 100) : 0;
return { isDownloading: isDownloadingAny, progress, details };
}
return { isDownloading, progress, downloadingRunners: downloadingRunners.length };
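For reference, the tagged-union wire format that getTagged unwraps comes from the Python TaggedModel types: a status serializes as a single-key object whose key names the variant. A hedged Python sketch of the shape this parser expects (key names taken from the code above; values illustrative, not the actual exo serializer output):

status = {
    "DownloadingRunnerStatus": {
        "downloadProgress": {
            "DownloadOngoing": {
                "nodeId": "node-1",
                "downloadProgress": {
                    "totalBytes": {"inBytes": 100},
                    "downloadedBytes": {"inBytes": 25},
                    "files": {},
                },
            }
        }
    }
}
kind, payload = next(iter(status.items()))  # the same unwrap getTagged() performs
assert kind == "DownloadingRunnerStatus"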
function buildDownloadDetailsHTML(details) {
if (!details || details.length === 0) return '';
function shortId(id) { return (id && id.length > 8) ? id.slice(0, 8) + '…' : (id || ''); }
return details.map(({ runnerId, nodeId, progress }) => {
const etaStr = formatDurationMs(progress.etaMs);
const pctStr = formatPercent(progress.percentage || 0, 2);
const bytesStr = `${formatBytes(progress.downloadedBytes)} / ${formatBytes(progress.totalBytes)}`;
const speedStr = formatBytesPerSecond(progress.speed);
const filesSummary = `${progress.completedFiles}/${progress.totalFiles}`;
const allFiles = progress.files || [];
const inProgressFiles = allFiles.filter(f => (f.percentage || 0) < 100);
const completedFiles = allFiles.filter(f => (f.percentage || 0) >= 100);
const inProgressHTML = inProgressFiles.map(f => {
const fPct = f.percentage || 0;
const fBytes = `${formatBytes(f.downloadedBytes)} / ${formatBytes(f.totalBytes)}`;
const fEta = formatDurationMs(f.etaMs);
const fSpeed = formatBytesPerSecond(f.speed);
const pctText = formatPercent(fPct, 2);
return `
<div class="download-file">
<div class="download-file-header">
<span class="download-file-name" title="${f.name}">${f.name}</span>
<span class="download-file-percent">${pctText}</span>
</div>
<div class="download-file-subtext">${fBytes} • ETA ${fEta}${fSpeed}</div>
<div class="progress-bar-container"><div class="progress-bar" style="width: ${Math.max(0, Math.min(100, fPct)).toFixed(2)}%;"></div></div>
</div>
`;
}).join('');
const completedHTML = completedFiles.length > 0 ? `
<div class="completed-files-section">
<div class="completed-files-header">Completed (${completedFiles.length})</div>
<div class="completed-files-list">
${completedFiles.map(f => `<div class="completed-file-item" title="${f.name}">${f.name}</div>`).join('')}
</div>
</div>
` : '';
const runnerName = (nodeId && nodeIdToFriendlyName[nodeId]) ? nodeIdToFriendlyName[nodeId] : '?';
const headerText = `${runnerName} (${shortId(nodeId || '')})`;
return `
<div class="download-details">
<div class="download-runner-header">${headerText}</div>
<div class="download-files-list">
${inProgressHTML}
</div>
${completedHTML}
</div>
`;
}).join('');
}
// Derive a display status for an instance from its runners.
// Priority: FAILED > DOWNLOADING > STARTING > RUNNING > LOADED > INACTIVE
function deriveInstanceStatus(instance, runners = {}) {
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
const runnerToShard = shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard ?? {};
const runnerIds = Object.keys(runnerToShard);
const runnerIds = Object.keys(instance.shardAssignments?.runnerToShard || {});
function getTagged(obj) {
if (!obj || typeof obj !== 'object') return [null, null];
const keys = Object.keys(obj);
if (keys.length === 1 && typeof keys[0] === 'string') {
return [keys[0], obj[keys[0]]];
}
return [null, null];
}
function canonicalStatusFromKind(kind) {
const map = {
DownloadingRunnerStatus: 'Downloading',
InactiveRunnerStatus: 'Inactive',
StartingRunnerStatus: 'Starting',
LoadedRunnerStatus: 'Loaded',
RunningRunnerStatus: 'Running',
FailedRunnerStatus: 'Failed',
};
return map[kind] || null;
}
const statuses = runnerIds
.map(rid => {
const r = runners[rid];
if (!r || typeof r !== 'object') return undefined;
if (typeof r.runner_status === 'string') return r.runner_status;
const tag = Object.keys(r)[0];
return typeof tag === 'string' ? tag.replace(/RunnerStatus$/,'') : undefined; // e.g. LoadedRunnerStatus -> Loaded
if (!r) return null;
const [kind] = getTagged(r);
if (kind) return canonicalStatusFromKind(kind);
const s = r.runnerStatus;
return (typeof s === 'string') ? s : null; // backward compatibility
})
.filter(s => typeof s === 'string');
@@ -1041,8 +1347,8 @@
const every = (pred) => statuses.length > 0 && statuses.every(pred);
if (statuses.length === 0) {
const instanceType = instance.instance_type ?? instance.instanceType;
const inactive = instanceType === 'INACTIVE' || instanceType === 'Inactive';
const it = instance.instanceType;
const inactive = (it === 'Inactive' || it === 'INACTIVE');
return { statusText: inactive ? 'INACTIVE' : 'LOADED', statusClass: inactive ? 'inactive' : 'loaded' };
}
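The resolution for non-empty status lists sits below this hunk; for orientation, a small Python sketch of the documented priority order (status names as canonicalized above; illustrative only):

PRIORITY = ["Failed", "Downloading", "Starting", "Running", "Loaded", "Inactive"]

def derive_status(statuses):
    # First match in priority order wins, per the comment above.
    for s in PRIORITY:
        if s in statuses:
            return s.upper()
    return "LOADED"

assert derive_status(["Running", "Downloading"]) == "DOWNLOADING"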
@@ -1072,12 +1378,10 @@
}
const instancesHTML = instancesArray.map(instance => {
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
const modelId = shardAssignments?.model_id ?? shardAssignments?.modelId ?? 'Unknown Model';
const instanceId = instance.instance_id ?? instance.instanceId ?? '';
const truncatedInstanceId = instanceId.length > 8
? instanceId.substring(0, 8) + '...'
: instanceId;
const modelId = instance.shardAssignments?.modelId || 'Unknown Model';
const truncatedInstanceId = instance.instanceId.length > 8
? instance.instanceId.substring(0, 8) + '...'
: instance.instanceId;
const hostsHTML = instance.hosts?.map(host =>
`<span class="instance-host">${host.ip}:${host.port}</span>`
@@ -1094,15 +1398,31 @@
}
// Generate download progress HTML
const downloadProgressHTML = downloadStatus.isDownloading
? `<div class="download-progress">
<span>${downloadStatus.progress}% downloaded</span>
<div class="progress-bar-container">
<div class="progress-bar" style="width: ${downloadStatus.progress}%;"></div>
</div>
</div>`
: '';
let downloadProgressHTML = '';
let instanceDownloadSummary = '';
if (downloadStatus.isDownloading) {
const detailsHTML = buildDownloadDetailsHTML(downloadStatus.details || []);
const pctText = (downloadStatus.progress || 0).toFixed(2);
// Build a compact summary from the first runner (per-runner aggregates should agree)
const first = (downloadStatus.details || [])[0]?.progress;
const etaStr = first ? formatDurationMs(first.etaMs) : '';
const bytesStr = first ? `${formatBytes(first.downloadedBytes)} / ${formatBytes(first.totalBytes)}` : '';
const speedStr = first ? formatBytesPerSecond(first.speed) : '';
const filesSummary = first ? `${first.completedFiles}/${first.totalFiles}` : '';
instanceDownloadSummary = `${etaStr} · ${bytesStr} · ${speedStr} · ${filesSummary} files`;
downloadProgressHTML = `
<div class="download-progress">
<span>${pctText}%</span>
<div class="progress-bar-container">
<div class="progress-bar" style="width: ${pctText}%;"></div>
</div>
</div>
${detailsHTML}
`;
}
const shardCount = Object.keys(instance.shardAssignments?.runnerToShard || {}).length;
return `
<div class="instance-item">
<div class="instance-header">
@@ -1111,15 +1431,14 @@
<span class="instance-status ${statusClass}">${statusText}</span>
</div>
<div class="instance-actions">
<button class="instance-delete-button" data-instance-id="${instanceId}" title="Delete Instance">
<button class="instance-delete-button" data-instance-id="${instance.instanceId}" title="Delete Instance">
Delete
</button>
</div>
</div>
<div class="instance-model">${modelId}</div>
<div class="instance-details">
Shards: ${Object.keys((shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard) || {}).length}
</div>
<div class="instance-model">${modelId} <span style="color: var(--exo-light-gray); opacity: 0.8;">(${shardCount})</span></div>
${instanceDownloadSummary ? `<div class="instance-download-summary">${instanceDownloadSummary}</div>` : ''}
${downloadProgressHTML}
${hostsHTML ? `<div class="instance-hosts">${hostsHTML}</div>` : ''}
</div>
@@ -1176,10 +1495,12 @@
}
}
function renderNodes(nodesData) {
function renderNodes(topologyData) {
if (!topologyGraphContainer) return;
topologyGraphContainer.innerHTML = ''; // Clear previous SVG content
const nodesData = (topologyData && topologyData.nodes) ? topologyData.nodes : {};
const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : [];
const nodeIds = Object.keys(nodesData);
if (nodeIds.length === 0) {
@@ -1214,23 +1535,128 @@
};
});
// Create group for links (drawn first, so they are behind nodes)
// Add arrowhead definition (supports bidirectional arrows on a single line)
const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs');
const marker = document.createElementNS('http://www.w3.org/2000/svg', 'marker');
marker.setAttribute('id', 'arrowhead');
marker.setAttribute('viewBox', '0 0 10 10');
marker.setAttribute('refX', '10');
marker.setAttribute('refY', '5');
marker.setAttribute('markerWidth', '11');
marker.setAttribute('markerHeight', '11');
marker.setAttribute('orient', 'auto-start-reverse');
// Draw a subtle V-tip (no filled body)
const markerTip = document.createElementNS('http://www.w3.org/2000/svg', 'path');
markerTip.setAttribute('d', 'M 0 0 L 10 5 L 0 10');
markerTip.setAttribute('fill', 'none');
markerTip.setAttribute('stroke', 'var(--exo-light-gray)');
markerTip.setAttribute('stroke-width', '1.6');
markerTip.setAttribute('stroke-linecap', 'round');
markerTip.setAttribute('stroke-linejoin', 'round');
markerTip.setAttribute('stroke-dasharray', 'none');
markerTip.setAttribute('stroke-dashoffset', '0');
markerTip.setAttribute('style', 'animation: none; pointer-events: none;');
marker.appendChild(markerTip);
defs.appendChild(marker);
topologyGraphContainer.appendChild(defs);
// Create groups for links and separate arrow markers (so arrows are not affected by line animations)
const linksGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
linksGroup.setAttribute('class', 'links-group');
linksGroup.setAttribute('style', 'pointer-events: none;');
const arrowsGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
arrowsGroup.setAttribute('class', 'arrows-group');
arrowsGroup.setAttribute('style', 'pointer-events: none;');
for (let i = 0; i < numNodes; i++) {
for (let j = i + 1; j < numNodes; j++) {
const link = document.createElementNS('http://www.w3.org/2000/svg', 'line');
link.setAttribute('x1', nodesWithPositions[i].x);
link.setAttribute('y1', nodesWithPositions[i].y);
link.setAttribute('x2', nodesWithPositions[j].x);
link.setAttribute('y2', nodesWithPositions[j].y);
link.setAttribute('class', 'graph-link');
linksGroup.appendChild(link);
// Build quick lookup for node positions
const positionById = {};
nodesWithPositions.forEach(n => { positionById[n.id] = { x: n.x, y: n.y }; });
// Group directed edges into undirected pairs to support single line with two arrows
const pairMap = new Map(); // key: "a|b" with a<b, value: { a, b, aToB, bToA }
edgesData.forEach(edge => {
if (!edge || !edge.source || !edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
if (edge.source === edge.target) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false };
if (edge.source === a && edge.target === b) entry.aToB = true; else entry.bToA = true;
pairMap.set(key, entry);
});
// Draw one line per undirected pair with separate arrow carrier lines
pairMap.forEach(entry => {
const posA = positionById[entry.a];
const posB = positionById[entry.b];
if (!posA || !posB) return;
// Full-length center-to-center lines
const x1 = posA.x;
const y1 = posA.y;
const x2 = posB.x;
const y2 = posB.y;
// Base animated dashed line (no markers)
const baseLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
baseLine.setAttribute('x1', x1);
baseLine.setAttribute('y1', y1);
baseLine.setAttribute('x2', x2);
baseLine.setAttribute('y2', y2);
baseLine.setAttribute('class', 'graph-link');
linksGroup.appendChild(baseLine);
// Arrowheads centered on the line (tip lies exactly on the line),
// offset along the tangent so opposite directions straddle the center.
const dx = x2 - x1;
const dy = y2 - y1;
const len = Math.hypot(dx, dy) || 1;
const ux = dx / len;
const uy = dy / len;
const mx = (x1 + x2) / 2;
const my = (y1 + y2) / 2;
const tipOffset = 16; // shift arrow tips away from the exact center along the line
const carrier = 2; // short carrier segment length to define orientation
if (entry.aToB) {
// Arrow pointing A -> B: place tip slightly before center along +tangent
const tipX = mx - ux * tipOffset;
const tipY = my - uy * tipOffset;
const sx = tipX - ux * carrier;
const sy = tipY - uy * carrier;
const ex = tipX;
const ey = tipY;
const arrowSeg = document.createElementNS('http://www.w3.org/2000/svg', 'line');
arrowSeg.setAttribute('x1', sx);
arrowSeg.setAttribute('y1', sy);
arrowSeg.setAttribute('x2', ex);
arrowSeg.setAttribute('y2', ey);
arrowSeg.setAttribute('stroke', 'none');
arrowSeg.setAttribute('fill', 'none');
arrowSeg.setAttribute('marker-end', 'url(#arrowhead)');
arrowsGroup.appendChild(arrowSeg);
}
}
topologyGraphContainer.appendChild(linksGroup);
if (entry.bToA) {
// Arrow pointing B -> A: place tip slightly after center along -tangent
const tipX = mx + ux * tipOffset;
const tipY = my + uy * tipOffset;
const sx = tipX + ux * carrier; // start ahead so the segment points toward tip
const sy = tipY + uy * carrier;
const ex = tipX;
const ey = tipY;
const arrowSeg = document.createElementNS('http://www.w3.org/2000/svg', 'line');
arrowSeg.setAttribute('x1', sx);
arrowSeg.setAttribute('y1', sy);
arrowSeg.setAttribute('x2', ex);
arrowSeg.setAttribute('y2', ey);
arrowSeg.setAttribute('stroke', 'none');
arrowSeg.setAttribute('fill', 'none');
arrowSeg.setAttribute('marker-end', 'url(#arrowhead)');
arrowsGroup.appendChild(arrowSeg);
}
});
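The tip placement above is plain vector math: take the unit tangent of the A-to-B line, then offset each arrow tip 16px from the midpoint, one per direction, so opposite arrows straddle the center. The same geometry distilled into a Python sketch (tipOffset as in the code above):

import math

def arrow_tips(x1, y1, x2, y2, tip_offset=16.0):
    dx, dy = x2 - x1, y2 - y1
    length = math.hypot(dx, dy) or 1.0
    ux, uy = dx / length, dy / length      # unit tangent along A -> B
    mx, my = (x1 + x2) / 2, (y1 + y2) / 2  # midpoint of the link
    tip_a_to_b = (mx - ux * tip_offset, my - uy * tip_offset)  # tip before center
    tip_b_to_a = (mx + ux * tip_offset, my + uy * tip_offset)  # tip after center
    return tip_a_to_b, tip_b_to_a

assert arrow_tips(0, 0, 100, 0) == ((34.0, 0.0), (66.0, 0.0))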
// Create group for nodes
const nodesGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
nodesGroup.setAttribute('class', 'nodes-group');
@@ -1738,7 +2164,10 @@
nodesGroup.appendChild(nodeG);
});
// Draw order: lines at the very back, then nodes, then mid-line arrows on top
topologyGraphContainer.appendChild(linksGroup);
topologyGraphContainer.appendChild(nodesGroup);
topologyGraphContainer.appendChild(arrowsGroup);
}
function showNodeDetails(selectedNodeId, allNodesData) {
@@ -1886,13 +2315,22 @@
throw new Error(`HTTP error! status: ${response.status} ${response.statusText}`);
}
const clusterState = await response.json();
const nodesData = transformClusterStateToTopology(clusterState);
renderNodes(nodesData);
const topologyData = transformClusterStateToTopology(clusterState);
// Build nodeId -> friendly name map
nodeIdToFriendlyName = {};
if (topologyData && topologyData.nodes) {
Object.keys(topologyData.nodes).forEach(nid => {
const n = topologyData.nodes[nid];
const name = (n && (n.friendly_name || (n.system_info && n.system_info.model_id))) || null;
if (name) nodeIdToFriendlyName[nid] = name;
});
}
renderNodes(topologyData);
// If a node was selected, and it still exists, refresh its details
if (currentlySelectedNodeId && nodesData[currentlySelectedNodeId]) {
showNodeDetails(currentlySelectedNodeId, nodesData);
} else if (currentlySelectedNodeId && !nodesData[currentlySelectedNodeId]) {
if (currentlySelectedNodeId && topologyData.nodes[currentlySelectedNodeId]) {
showNodeDetails(currentlySelectedNodeId, topologyData.nodes);
} else if (currentlySelectedNodeId && !topologyData.nodes[currentlySelectedNodeId]) {
// If selected node is gone, close panel and clear selection
nodeDetailPanel.classList.remove('visible');
currentlySelectedNodeId = null;
@@ -1938,8 +2376,9 @@
}
function transformClusterStateToTopology(clusterState) {
const result = {};
if (!clusterState) return result;
const resultNodes = {};
const resultEdges = [];
if (!clusterState) return { nodes: resultNodes, edges: resultEdges };
// Helper: get numeric bytes from various shapes (number | {in_bytes}|{inBytes})
function getBytes(value) {
@@ -1959,18 +2398,23 @@
return fallback;
};
// Process nodes from topology or fallback to node_profiles/nodeProfiles directly
// Helper: detect API placeholders like "unknown" (case-insensitive)
const isUnknown = (value) => {
return typeof value === 'string' && value.trim().toLowerCase() === 'unknown';
};
// Process nodes from topology or fallback to nodeProfiles directly (support both snake_case and camelCase)
let nodesToProcess = {};
if (clusterState.topology && Array.isArray(clusterState.topology.nodes)) {
clusterState.topology.nodes.forEach(node => {
const nid = node.node_id ?? node.nodeId;
const nprof = node.node_profile ?? node.nodeProfile;
const nid = node.nodeId ?? node.node_id;
const nprof = node.nodeProfile ?? node.node_profile;
if (nid && nprof) {
nodesToProcess[nid] = nprof;
}
});
} else if (clusterState.node_profiles || clusterState.nodeProfiles) {
nodesToProcess = clusterState.node_profiles ?? clusterState.nodeProfiles;
} else if (clusterState.nodeProfiles || clusterState.node_profiles) {
nodesToProcess = clusterState.nodeProfiles || clusterState.node_profiles;
}
// Transform each node
@@ -1991,10 +2435,15 @@
memBytesAvailable = getBytes(ramAvailVal);
const memBytesUsed = Math.max(memBytesTotal - memBytesAvailable, 0);
// Extract model information
const modelId = pick(nodeProfile, 'model_id', 'modelId', 'Unknown');
const chipId = pick(nodeProfile, 'chip_id', 'chipId', '');
const friendlyName = pick(nodeProfile, 'friendly_name', 'friendlyName', `${nodeId.substring(0, 8)}...`);
// Extract model information with graceful placeholders while node is loading
const rawModelId = pick(nodeProfile, 'model_id', 'modelId', 'Unknown');
const rawChipId = pick(nodeProfile, 'chip_id', 'chipId', '');
const rawFriendlyName = pick(nodeProfile, 'friendly_name', 'friendlyName', `${nodeId.substring(0, 8)}...`);
// When API has not fully loaded (reports "unknown"), present a nice default
const modelId = isUnknown(rawModelId) ? 'Mac Studio' : rawModelId;
const chipId = isUnknown(rawChipId) ? '' : rawChipId;
const friendlyName = (!rawFriendlyName || isUnknown(rawFriendlyName)) ? 'Mac' : rawFriendlyName;
// Extract network addresses (support snake_case and camelCase)
const addrList = [];
@@ -2039,7 +2488,7 @@
timestamp: new Date().toISOString()
};
result[nodeId] = {
resultNodes[nodeId] = {
mem: memBytesTotal,
addrs: addrList,
last_addr_update: Date.now() / 1000,
@@ -2053,7 +2502,21 @@
};
}
return result;
// Extract directed edges from topology.connections if present (support camelCase)
const connections = clusterState.topology && Array.isArray(clusterState.topology.connections)
? clusterState.topology.connections
: [];
connections.forEach(conn => {
if (!conn) return;
const src = conn.localNodeId ?? conn.local_node_id;
const dst = conn.sendBackNodeId ?? conn.send_back_node_id;
if (!src || !dst) return;
if (!resultNodes[src] || !resultNodes[dst]) return; // only draw edges between known nodes
if (src === dst) return; // skip self loops for now
resultEdges.push({ source: src, target: dst });
});
return { nodes: resultNodes, edges: resultEdges };
}
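For orientation, a hedged sketch of the input this transform consumes and the {nodes, edges} pair it now returns (key names from the parsing above; values illustrative):

cluster_state = {
    "topology": {
        "nodes": [
            {"nodeId": "a", "nodeProfile": {"friendlyName": "Mac One"}},
            {"nodeId": "b", "nodeProfile": {"friendlyName": "Mac Two"}},
        ],
        "connections": [
            {"localNodeId": "a", "sendBackNodeId": "b"},
            {"localNodeId": "b", "sendBackNodeId": "a"},
        ],
    }
}
# transformClusterStateToTopology(cluster_state) would yield two nodes plus
# edges [{source: "a", target: "b"}, {source: "b", target: "a"}], which the
# renderer collapses into one line carrying arrowheads in both directions.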
// --- Conditional Data Handling ---
@@ -2193,11 +2656,12 @@
mi.timestamp = new Date().toISOString();
}
}
renderNodes(mockData);
const mockTopology = { nodes: mockData, edges: [] };
renderNodes(mockTopology);
lastUpdatedElement.textContent = `Last updated: ${new Date().toLocaleTimeString()} (Mock Data)`;
if (currentlySelectedNodeId && mockData[currentlySelectedNodeId]) {
showNodeDetails(currentlySelectedNodeId, mockData);
showNodeDetails(currentlySelectedNodeId, mockTopology.nodes);
} else if (currentlySelectedNodeId && !mockData[currentlySelectedNodeId]) {
nodeDetailPanel.classList.remove('visible');
currentlySelectedNodeId = null;

View File

@@ -51,11 +51,11 @@ dev = [
"ruff>=0.11.13",
]
# dependencies only required for Apple Silicon
[project.optional-dependencies]
darwin = [
"mlx",
]
# mlx[cuda] requires a newer version of mlx. The ideal on Linux: default to mlx[cpu] unless [cuda] is specified.
# [project.optional-dependencies]
# cuda = [
# "mlx[cuda]==0.26.3",
# ]
###
# workspace configuration

View File

@@ -3,6 +3,7 @@ import concurrent.futures
import contextlib
import os
import resource
from loguru import logger
from asyncio import AbstractEventLoop
from typing import Any, Callable, Optional, cast
@@ -63,6 +64,9 @@ def mlx_setup(
cache_frac_of_mrwss: float = 0.65, # main workhorse
wired_frac_of_mrwss: float = 0.00, # start with no wiring
) -> None:
if not mx.metal.is_available():
logger.warning("Metal is not available. Skipping MLX memory wired limits setup.")
return
info = mx.metal.device_info()
mrwss = int(info["max_recommended_working_set_size"]) # bytes
memsize = int(info["memory_size"]) # bytes
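When Metal is available, the remaining lines budget MLX limits as fractions of the max recommended working set size (MRWSS). A sketch of the arithmetic those defaults imply (the actual limit calls sit below this hunk):

def mlx_memory_budget(mrwss_bytes: int, cache_frac: float = 0.65, wired_frac: float = 0.00) -> tuple[int, int]:
    # Mirrors cache_frac_of_mrwss (the "main workhorse") and wired_frac_of_mrwss.
    return int(cache_frac * mrwss_bytes), int(wired_frac * mrwss_bytes)

cache_limit, wired_limit = mlx_memory_budget(96 * 2**30)  # assume a 96 GiB MRWSS
# cache_limit is ~62.4 GiB; wired_limit is 0 with the "start with no wiring" default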

View File

@@ -1,38 +0,0 @@
import asyncio
import pytest
@pytest.mark.asyncio
async def test_master_api_multiple_response_sequential() -> None:
# TODO
return
messages = [ChatMessage(role="user", content="Hello, who are you?")]
token_count = 0
text: str = ""
async for choice in stream_chatgpt_response(messages):
print(choice, flush=True)
if choice.delta and choice.delta.content:
text += choice.delta.content
token_count += 1
if choice.finish_reason:
break
assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
assert len(text) > 0, "Expected non-empty response text"
await asyncio.sleep(0.1)
messages = [ChatMessage(role="user", content="What time is it in France?")]
token_count = 0
text = "" # re-initialize, do not redeclare type
async for choice in stream_chatgpt_response(messages):
print(choice, flush=True)
if choice.delta and choice.delta.content:
text += choice.delta.content
token_count += 1
if choice.finish_reason:
break
assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
assert len(text) > 0, "Expected non-empty response text"

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from pydantic import ConfigDict, Field, field_validator, field_serializer
from pydantic import ConfigDict, Field, field_serializer, field_validator
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId

View File

@@ -6,7 +6,15 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
completed_files: int
total_files: int
speed: float
eta_ms: int
files: dict[str, "DownloadProgressData"]
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
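The new files field makes the model recursive: each entry is itself a DownloadProgressData, which is what the dashboard walks to render per-file rows. A standalone mimic of the shape (Memory fields simplified to raw byte counts; the real fields use exo.shared.types.memory.Memory):

from __future__ import annotations
from dataclasses import dataclass, field

@dataclass
class ProgressSketch:
    total_bytes: int
    downloaded_bytes: int
    downloaded_bytes_this_session: int = 0
    completed_files: int = 0
    total_files: int = 0
    speed: float = 0.0
    eta_ms: int = 0
    files: dict[str, "ProgressSketch"] = field(default_factory=dict)

repo = ProgressSketch(
    total_bytes=100, downloaded_bytes=40, total_files=2, completed_files=1,
    files={"model.safetensors": ProgressSketch(total_bytes=90, downloaded_bytes=40)},
)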

View File

@@ -12,9 +12,19 @@ from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter, ConfigDict
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_HOME
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -40,15 +50,13 @@ class FileListEntry(BaseModel):
class RepoFileDownloadProgress(BaseModel):
"""Progress information for an individual file within a repository download."""
repo_id: str
repo_revision: str
file_path: str
downloaded: int
downloaded_this_session: int
total: int
speed: float # bytes per second
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
@@ -57,40 +65,50 @@ class RepoFileDownloadProgress(BaseModel):
class RepoDownloadProgress(BaseModel):
"""Aggregated download progress information for a repository/shard combination.
This structure captures the overall progress of downloading the files
required to materialise a particular *shard* of a model. It purposely
mirrors the key summary fields emitted by the `RepoProgressEvent` so that
the event payload can be cleanly projected onto the long-lived cluster
state.
"""
repo_id: str
repo_revision: str
shard: ShardMetadata
# progress totals
completed_files: int
total_files: int
downloaded_bytes: int
downloaded_bytes_this_session: int
total_bytes: int
# speed / eta
overall_speed: float # bytes per second
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
# lifecycle status
status: Literal["not_started", "in_progress", "complete"]
# fine-grained file progress keyed by file_path
file_progress: Dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(
frozen = True # allow use as dict keys if desired
frozen = True
)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
return etag
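Example behavior of trim_etag, which strips one layer of matching quotes from ETag values:

assert trim_etag('"abc123"') == 'abc123'
assert trim_etag("'abc123'") == 'abc123'
assert trim_etag('abc123') == 'abc123'  # unquoted values pass through unchanged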
def map_repo_file_download_progress_to_download_progress_data(repo_file_download_progress: RepoFileDownloadProgress) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
eta_ms=int(repo_file_download_progress.eta.total_seconds() * 1000),
files={},
)
def map_repo_download_progress_to_download_progress_data(repo_download_progress: RepoDownloadProgress) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
eta_ms=int(repo_download_progress.overall_eta.total_seconds() * 1000),
files={file_path: map_repo_file_download_progress_to_download_progress_data(file_progress) for file_path, file_progress in repo_download_progress.file_progress.items()},
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_HOME / "models" / model_id.replace("/", "--")
@@ -141,13 +159,13 @@ async def seed_models(seed_dir: Union[str, Path]):
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if await aios.path.exists(dest_path):
print("Skipping moving model to .cache directory")
logger.info("Skipping moving model to .cache directory")
else:
try:
await aios.rename(str(path), str(dest_path))
except Exception:
print(f"Error seeding model {path} to {dest_path}")
traceback.print_exc()
logger.error(f"Error seeding model {path} to {dest_path}")
logger.error(traceback.format_exc())
async def fetch_file_list_with_cache(
@@ -192,13 +210,9 @@ async def _fetch_file_list(
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_auth_headers()
headers = await get_download_headers()
async with (
aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=30, connect=10, sock_read=30, sock_connect=10
)
) as session,
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
if response.status == 200:
@@ -218,6 +232,34 @@ async def _fetch_file_list(
raise Exception(f"Failed to fetch file list: {response.status}")
async def get_download_headers() -> dict[str, str]:
return {**(await get_auth_headers()), "Accept-Encoding": "identity"}
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> aiohttp.ClientSession:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
sock_read_timeout = 30
sock_connect_timeout = 10
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
),
)
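A hedged usage sketch of the new helper (fetch_json is illustrative, not part of the diff; the short profile matches the metadata calls above, the long profile the actual file downloads):

async def fetch_json(url: str):
    # Metadata-style request on the 30s profile, mirroring _fetch_file_list.
    async with create_http_session(timeout_profile="short") as session:
        async with session.get(url, headers=await get_download_headers()) as resp:
            resp.raise_for_status()
            return await resp.json()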
async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -> str:
hasher = hashlib.sha1() if hash_type == "sha1" else hashlib.sha256()
if hash_type == "sha1":
@@ -237,46 +279,29 @@ async def file_meta(
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
headers = await get_auth_headers()
headers = await get_download_headers()
async with (
aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=1800, connect=60, sock_read=1800, sock_connect=60
)
) as session,
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
):
if r.status == 307:
# Try to extract from X-Linked headers first (common for HF redirects)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
etag = (
r.headers.get("X-Linked-ETag")
or r.headers.get("ETag")
or r.headers.get("Etag")
)
if content_length > 0 and etag is not None:
if (etag[0] == '"' and etag[-1] == '"') or (
etag[0] == "'" and etag[-1] == "'"
):
etag = etag[1:-1]
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# If not available, recurse with the redirect
redirected_location = r.headers.get("Location")
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
etag = (
r.headers.get("X-Linked-ETag")
or r.headers.get("ETag")
or r.headers.get("Etag")
)
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
etag = etag[1:-1]
etag = trim_etag(etag)
return content_length, etag
@@ -296,10 +321,10 @@ async def download_file_with_retry(
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
print(
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
)
traceback.print_exc()
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
@@ -326,23 +351,13 @@ async def _download_file(
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_auth_headers()
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
n_read = resume_byte_pos or 0
async with (
aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=1800, connect=60, sock_read=1800, sock_connect=60
)
) as session,
session.get(
url,
headers=headers,
timeout=aiohttp.ClientTimeout(
total=1800, connect=60, sock_read=1800, sock_connect=60
),
) as r,
create_http_session(timeout_profile="long") as session,
session.get(url, headers=headers) as r,
):
if r.status == 404:
raise FileNotFoundError(f"File not found: {url}")
@@ -364,7 +379,7 @@ async def _download_file(
try:
await aios.remove(partial_path)
except Exception as e:
print(f"Error removing partial file {partial_path}: {e}")
logger.error(f"Error removing partial file {partial_path}: {e}")
raise Exception(
f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
)
@@ -379,11 +394,9 @@ def calculate_repo_progress(
file_progress: Dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total_bytes = sum(p.total for p in file_progress.values())
all_downloaded_bytes = sum(p.downloaded for p in file_progress.values())
all_downloaded_bytes_this_session = sum(
p.downloaded_this_session for p in file_progress.values()
)
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum((p.downloaded.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes_this_session = sum((p.downloaded_this_session.in_bytes for p in file_progress.values()), 0)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
@@ -408,9 +421,9 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded_bytes=all_downloaded_bytes,
downloaded_bytes_this_session=all_downloaded_bytes_this_session,
total_bytes=all_total_bytes,
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(all_downloaded_bytes_this_session),
total_bytes=Memory.from_bytes(all_total_bytes),
overall_speed=all_speed,
overall_eta=all_eta,
status=status,
@@ -434,8 +447,8 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> List[str]:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
print(f"Error getting weight map for {shard.model_meta.model_id=}")
traceback.print_exc()
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -451,13 +464,11 @@ async def get_downloaded_size(path: Path) -> int:
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
# Scan local files for accurate progress reporting
file_progress: Dict[str, RepoFileDownloadProgress] = {}
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
# Recursively count files and sizes
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
@@ -468,9 +479,9 @@ async def download_progress_for_local_path(
repo_id=repo_id,
repo_revision="local",
file_path=rel_path,
downloaded=size,
downloaded_this_session=0,
total=size,
downloaded=Memory.from_bytes(size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(size),
speed=0,
eta=timedelta(0),
status="complete",
@@ -487,9 +498,9 @@ async def download_progress_for_local_path(
shard=shard,
completed_files=total_files,
total_files=total_files,
downloaded_bytes=total_bytes,
downloaded_bytes_this_session=0,
total_bytes=total_bytes,
downloaded_bytes=Memory.from_bytes(total_bytes),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(total_bytes),
overall_speed=0,
overall_eta=timedelta(0),
status="complete",
@@ -505,11 +516,11 @@ async def download_shard(
allow_patterns: List[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
print(f"Downloading {shard.model_meta.model_id=}")
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
print(f"Using local model path {shard.model_meta.model_id}")
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_meta.model_id), shard, local_path
@@ -525,7 +536,7 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
print(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
@@ -546,8 +557,8 @@ async def download_shard(
else time.time()
)
downloaded_this_session = (
file_progress[file.path].downloaded_this_session
+ (curr_bytes - file_progress[file.path].downloaded)
file_progress[file.path].downloaded_this_session.in_bytes
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
if file.path in file_progress
else curr_bytes
)
@@ -565,9 +576,9 @@ async def download_shard(
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=curr_bytes,
downloaded_this_session=downloaded_this_session,
total=total_bytes,
downloaded=Memory.from_bytes(curr_bytes),
downloaded_this_session=Memory.from_bytes(downloaded_this_session),
total=Memory.from_bytes(total_bytes),
speed=speed,
eta=eta,
status="complete" if curr_bytes == total_bytes else "in_progress",
@@ -590,9 +601,9 @@ async def download_shard(
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=downloaded_bytes,
downloaded_this_session=0,
total=file.size or 0,
downloaded=Memory.from_bytes(downloaded_bytes),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(file.size or 0),
speed=0,
eta=timedelta(0),
status="complete" if downloaded_bytes == file.size else "not_started",

View File

@@ -64,9 +64,9 @@ class ShardDownloader(ABC):
),
completed_files=0,
total_files=0,
downloaded_bytes=0,
downloaded_bytes_this_session=0,
total_bytes=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
@@ -113,9 +113,9 @@ class NoopShardDownloader(ShardDownloader):
),
completed_files=0,
total_files=0,
downloaded_bytes=0,
downloaded_bytes_this_session=0,
total_bytes=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
@@ -131,9 +131,9 @@ class NoopShardDownloader(ShardDownloader):
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=0,
downloaded_bytes_this_session=0,
total_bytes=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",

View File

@@ -30,7 +30,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -41,7 +40,6 @@ from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgressData,
)
from exo.shared.types.worker.ops import (
AssignRunnerOp,
@@ -64,6 +62,9 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import OrderedBuffer
from exo.worker.common import AssignedRunner
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -318,12 +319,7 @@ class Worker:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(initial_progress.total_bytes),
downloaded_bytes=Memory.from_bytes(
initial_progress.downloaded_bytes
),
),
download_progress=map_repo_download_progress_to_download_progress_data(initial_progress),
)
)
yield assigned_runner.status_update_event()
@@ -377,12 +373,7 @@ class Worker:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(progress.total_bytes),
downloaded_bytes=Memory.from_bytes(
progress.downloaded_bytes
),
),
download_progress=map_repo_download_progress_to_download_progress_data(progress),
)
)
yield assigned_runner.status_update_event()

View File

@@ -117,7 +117,7 @@ def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
download_progress=DownloadOngoing(
node_id=node_id,
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0)
total_bytes=Memory.from_bytes(1),
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
completed_files=0,
total_files=0,
speed=0,
eta_ms=0,
files={},
),
)
)

View File

@@ -7,6 +7,7 @@ import anyio
import psutil
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
@@ -48,16 +49,14 @@ async def get_memory_profile_async() -> MemoryPerformanceProfile:
vm = psutil.virtual_memory()
sm = psutil.swap_memory()
override_memory_env = os.getenv("OVERRIDE_MEMORY")
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
int(override_memory_env) * 2**30 if override_memory_env else None
Memory.from_mb(int(override_memory_env)).in_bytes if override_memory_env else None
)
return MemoryPerformanceProfile.from_bytes(
ram_total=int(vm.total),
ram_available=int(override_memory)
if override_memory
else int(vm.available),
ram_available=int(override_memory) if override_memory else int(vm.available),
swap_total=int(sm.total),
swap_available=int(sm.free),
)
@@ -99,14 +98,15 @@ async def start_polling_node_metrics(
system_info,
network_interfaces,
mac_friendly_name,
memory_profile,
) = await asyncio.gather(
get_mac_system_info_async(),
get_network_interface_info_async(),
get_mac_friendly_name_async(),
get_memory_profile_async(),
)
# profile memory last so the reading is fresh and doesn't conflict with the other memory-profiling loop
memory_profile = await get_memory_profile_async()
await callback(
NodePerformanceProfile(
model_id=system_info.model_id,

uv.lock generated (1047 changed lines)
View File

File diff suppressed because it is too large.