Compare commits

..

2 Commits

Author SHA1 Message Date
Alex Cheema
6ab1cb285b Enhance LaTeX rendering in dashboard markdown
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 14:49:28 +00:00
Antonio Lujano Luna
9c29eb7d48 Add proxy and custom SSL certificate support for corporate networks (#1189)
Support HTTPS_PROXY/HTTP_PROXY environment variables for proxy
configuration and SSL_CERT_FILE for custom CA certificates, enabling use
in corporate environments with SSL inspection.

## Motivation
Users in corporate environments often need to route traffic through HTTP
proxies and use custom CA certificates for SSL inspection. Without this
support, exo cannot download models in these network configurations.

## Changes
- Added `HTTPS_PROXY`/`HTTP_PROXY` environment variable support to
`create_http_session()` in `download_utils.py`
- Added `SSL_CERT_FILE` environment variable support for custom CA
certificate bundles, falling back to certifi's default bundle

## Why It Works
- `aiohttp.ClientSession` natively supports the `proxy` parameter for
routing requests through HTTP proxies
- `ssl.create_default_context(cafile=...)` accepts a custom CA bundle
path, allowing corporate CAs to be trusted
- Using environment variables is consistent with the codebase's existing
configuration patterns (e.g., `EXO_HOME`, `HF_ENDPOINT`)

## Test Plan
### Manual Testing
- Set `HTTPS_PROXY` environment variable and verified model downloads
route through proxy
- Set `SSL_CERT_FILE` to custom CA bundle and verified SSL verification
succeeds with corporate SSL inspection

### Automated Testing
- No automated tests added; this change is configuration-only and does
not alter existing behavior when environment variables are unset
2026-01-18 12:05:50 +00:00
13 changed files with 512 additions and 324 deletions

View File

@@ -53,62 +53,188 @@
marked.use({ renderer });
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
// Reset math storage
mathExpressions.clear();
mathCounter = 0;
// Protect code blocks first
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
});
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, '[diagram]');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*'];
for (const env of mathEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(envRegex, (_, content) => {
// For align environments, wrap content properly for KaTeX
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
'<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">$1</div></div>'
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(
envRegex,
`<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">$1</div></div>`
);
}
// Convert LaTeX text formatting commands
processed = processed.replace(/\\emph\{([^}]*)\}/g, '<em>$1</em>');
processed = processed.replace(/\\textit\{([^}]*)\}/g, '<em>$1</em>');
processed = processed.replace(/\\textbf\{([^}]*)\}/g, '<strong>$1</strong>');
processed = processed.replace(/\\texttt\{([^}]*)\}/g, '<code class="inline-code">$1</code>');
processed = processed.replace(/\\underline\{([^}]*)\}/g, '<u>$1</u>');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\((.+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Restore code blocks
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX after HTML is generated
* Render math expressions with KaTeX - restores placeholders with rendered math
*/
function renderMath(html: string): string {
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
return html;
}
@@ -154,16 +280,50 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -424,28 +584,263 @@
color: #60a5fa;
}
/* KaTeX math styling */
/* KaTeX math styling - Base */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
.markdown-content :global(.katex-display) {
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
.markdown-content :global(.katex-display > .katex) {
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.125rem 0.25rem;
padding: 0.25rem 0.5rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
</style>

View File

@@ -69,8 +69,6 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
interface RawNodeProfile {

View File

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -58,8 +58,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
draftModel: string | null;
numDraftTokens: number;
}
function saveLaunchDefaults(): void {
@@ -68,8 +66,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
draftModel: selectedDraftModel,
numDraftTokens: selectedNumDraftTokens,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
@@ -92,36 +88,24 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
// Apply draft model if it exists in the available models (check against hugging_face_id)
if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
selectedDraftModel = defaults.draftModel;
}
// Apply num draft tokens if valid
if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
selectedNumDraftTokens = defaults.numDraftTokens;
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let selectedDraftModel = $state<string | null>(null);
let selectedNumDraftTokens = $state<number>(4);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
@@ -129,8 +113,6 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
let isDraftModelDropdownOpen = $state(false);
let draftModelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -380,39 +362,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
// This also supports draft_model for speculative decoding
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
draft_model: selectedDraftModel,
num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
};
response = await fetch('/place_instance', {
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(placePayload)
body: JSON.stringify({ instance: instanceData })
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -826,34 +816,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
};
draftModel?: string;
numDraftTokens?: number;
}
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -864,7 +850,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -872,12 +858,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
}
function formatLastUpdate(): string {
@@ -1363,9 +1345,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1699,80 +1678,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/each}
</div>
</div>
<!-- Draft Model (Speculative Decoding) -->
<div>
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
<div class="relative">
<button
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftModelDropdownOpen}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-40"
onclick={() => isDraftModelDropdownOpen = false}
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
></div>
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftModelDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span>None</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens (only show when draft model selected) -->
{#if selectedDraftModel}
<div class="flex items-center gap-2 mt-2">
<span class="text-xs text-white/50 font-mono">Tokens:</span>
<div class="flex items-center gap-1">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
>{n}</button>
{/each}
</div>
</div>
{/if}
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}

View File

@@ -200,8 +200,6 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)

View File

@@ -151,8 +151,6 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -166,8 +164,6 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -161,8 +161,6 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -174,8 +172,6 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod

View File

@@ -2,7 +2,7 @@ from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):

View File

@@ -3,7 +3,6 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -245,12 +245,15 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(cafile=certifi.where())
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,

View File

@@ -119,8 +119,6 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -137,6 +135,8 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -149,31 +149,19 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
stats: GenerationStats | None = None

View File

@@ -258,27 +258,6 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,

View File

@@ -3,8 +3,7 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -36,7 +35,6 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -59,7 +57,6 @@ def plan(
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _draft_model_needs_download(runners, download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -131,57 +128,6 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
"""Check if draft model needs download (for speculative decoding).
Only rank 0 needs the draft model, and only after the main model is loaded.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
shard = runner.bound_instance.bound_shard
# Only check when runner is loaded and ready for warmup
if not isinstance(runner.status, RunnerLoaded):
continue
# Only rank 0 loads the draft model
if shard.device_rank != 0:
continue
# Check if instance has a draft model configured
draft_model_id = instance.draft_model
if draft_model_id is None:
continue
# Check if draft model needs download
if draft_model_id not in download_status or not isinstance(
download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
):
# Create minimal shard metadata for draft model download
draft_shard = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=draft_model_id,
pretty_name=str(draft_model_id),
storage_size=Memory.from_bytes(0), # Unknown, will be determined during download
n_layers=1, # Placeholder
hidden_size=1, # Placeholder
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
return DownloadModel(
instance_id=instance.instance_id,
shard_metadata=draft_shard,
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],

View File

@@ -56,7 +56,6 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
@@ -111,7 +110,6 @@ def main(
model = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -180,20 +178,11 @@ def main(
)
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -223,13 +212,11 @@ def main(
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses (draft_model loaded at warmup if configured)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
@@ -278,7 +265,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group, draft_model
del model, tokenizer, group
mx.clear_cache()
import gc