mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-15 17:39:45 -05:00
Compare commits
4 Commits
model-card
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e6567a802 | ||
|
|
a735dad667 | ||
|
|
aaf4e36bc3 | ||
|
|
3e623ccf0d |
@@ -60,12 +60,39 @@
|
||||
return models;
|
||||
});
|
||||
|
||||
// Auto-select the first available model if none is selected
|
||||
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
|
||||
let previousModelIds: Set<string> = new Set();
|
||||
|
||||
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
|
||||
$effect(() => {
|
||||
const models = availableModels();
|
||||
if (models.length > 0 && !currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
const currentModelIds = new Set(models.map(m => m.id));
|
||||
|
||||
if (models.length > 0) {
|
||||
// Find newly added models (in current but not in previous)
|
||||
const newModels = models.filter(m => !previousModelIds.has(m.id));
|
||||
|
||||
// If no model selected, select the first available
|
||||
if (!currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If current model is stale (no longer has a running instance), reset to first available
|
||||
else if (!models.some(m => m.id === currentModel)) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If a new model was just added, select it
|
||||
else if (newModels.length > 0 && previousModelIds.size > 0) {
|
||||
setSelectedChatModel(newModels[0].id);
|
||||
}
|
||||
} else {
|
||||
// No instances running - clear the selected model
|
||||
if (currentModel) {
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
}
|
||||
|
||||
// Update previous model IDs for next comparison
|
||||
previousModelIds = currentModelIds;
|
||||
});
|
||||
|
||||
function getInstanceModelId(instanceWrapped: unknown): string {
|
||||
|
||||
@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const errorText = await response.text();
|
||||
console.error('Failed to launch instance:', errorText);
|
||||
} else {
|
||||
// Auto-select the launched model only if no model is currently selected
|
||||
if (!selectedChatModel()) {
|
||||
setSelectedChatModel(modelId);
|
||||
}
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
async function deleteInstance(instanceId: string) {
|
||||
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
|
||||
|
||||
// Get the model ID of the instance being deleted before we delete it
|
||||
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
|
||||
const wasSelected = selectedChatModel() === deletedInstanceModelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/instance/${instanceId}`, {
|
||||
method: 'DELETE',
|
||||
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Failed to delete instance:', response.status);
|
||||
} else if (wasSelected) {
|
||||
// If we deleted the currently selected model, switch to another available model
|
||||
// Find another instance that isn't the one we just deleted
|
||||
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
|
||||
if (remainingInstances.length > 0) {
|
||||
// Select the last instance (most recently added, since objects preserve insertion order)
|
||||
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
|
||||
const newModelId = getInstanceModelId(lastInstance);
|
||||
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
|
||||
setSelectedChatModel(newModelId);
|
||||
} else {
|
||||
// Clear selection if no valid model found
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
} else {
|
||||
// No more instances, clear the selection
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting instance:', error);
|
||||
|
||||
2
justfile
2
justfile
@@ -1,3 +1,5 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
|
||||
|
||||
@@ -13,12 +13,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -67,8 +61,6 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
@@ -381,35 +373,8 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
@@ -417,16 +382,10 @@ class API:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
with recv as token_chunks:
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -442,11 +401,11 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
@@ -458,7 +417,7 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
@@ -466,7 +425,7 @@ class API:
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -495,7 +454,7 @@ class API:
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
@@ -503,7 +462,7 @@ class API:
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -544,8 +503,6 @@ class API:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -562,17 +519,16 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
self._generate_chat_stream(command.command_id),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
@@ -589,10 +545,7 @@ class API:
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
|
||||
@@ -425,15 +425,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"gpt-oss-20b-4bit": ModelCard(
|
||||
short_id="gpt-oss-20b-4bit",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-20b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
|
||||
@@ -20,6 +20,7 @@ except ImportError:
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -365,6 +366,8 @@ def apply_chat_template(
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -396,6 +399,11 @@ def make_kv_cache(
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
@@ -153,11 +162,19 @@ def main(
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
for response in mlx_generate(
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
):
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
@@ -207,6 +224,43 @@ def main(
|
||||
break
|
||||
|
||||
|
||||
@cache
|
||||
def get_gpt_oss_encoding():
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return encoding
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
for response in responses:
|
||||
stream.process(response.token)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
break
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import http.client
|
||||
import time
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
@@ -6,6 +7,8 @@ from loguru import logger
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
BAD_STATUSLINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
@@ -15,8 +18,9 @@ async def check_reachability(
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
# TODO: use an async http client
|
||||
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
|
||||
try:
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
@@ -32,7 +36,16 @@ async def check_reachability(
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.HTTPException:
|
||||
except http.client.BadStatusLine:
|
||||
if attempt >= BAD_STATUSLINE_ATTEMPTS:
|
||||
logger.warning(
|
||||
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
|
||||
)
|
||||
return None
|
||||
time.sleep(1)
|
||||
return _fetch_remote_node_id(attempt=attempt + 1)
|
||||
except http.client.HTTPException as e:
|
||||
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
Reference in New Issue
Block a user