Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-18 23:06:23 -05:00.

Compare commits: 20 commits on support-ml by alexcheema.

| Author | SHA1 | Date |
|---|---|---|
| | f8748eade1 | |
| | 83af8c63fa | |
| | eccc6298d1 | |
| | c8997217cf | |
| | 490d2e46ba | |
| | facf2d4d03 | |
| | a962a28afc | |
| | db79c350c1 | |
| | d6301ed593 | |
| | 6d1ca6689b | |
| | c01b6fff21 | |
| | 8392e78afe | |
| | 86735ece78 | |
| | 2759e92334 | |
| | 131fb141a6 | |
| | 2d8bfc2e3c | |
| | 042999f728 | |
| | b61dc2eb35 | |
| | 36a7115b6f | |
| | 0b7d88b43b | |
.github/workflows/pipeline.yml (vendored, 27 changed lines)
@@ -8,33 +8,6 @@ on:
      - main

jobs:
  typecheck:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - uses: cachix/install-nix-action@v31
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Load nix develop environment
        run: nix run github:nicknovitski/nix-develop/v1

      - name: Sync dependencies
        run: uv sync --all-packages

      - name: Run type checker
        run: uv run basedpyright --project pyproject.toml

  nix:
    name: Build and check (${{ matrix.system }})
    runs-on: ${{ matrix.runner }}
@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
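For the _command_processor item in the list above, a minimal sketch of the described change, assuming placeholder names for everything except _command_processor itself (the real receiver, handler, and logger setup live in exo, not here):

```python
# Hypothetical sketch of the checklist item above: broaden the except clause so
# unexpected errors (e.g. the KeyError mentioned) are logged instead of being
# silently swallowed. Names other than _command_processor are placeholders.
import logging

logger = logging.getLogger(__name__)


class CommandProcessor:
    async def _command_processor(self) -> None:
        async for command in self.command_receiver:  # placeholder receiver
            try:
                await self.handle_command(command)  # placeholder handler
            except Exception:
                # Previously only ValueError was caught, hiding other failures.
                logger.exception("Failed to process command: %r", command)
```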
@@ -126,11 +126,37 @@ final class ExoProcessController: ObservableObject {
            return
        }
        process.terminationHandler = nil
        if process.isRunning {
            process.terminate()
        }
        self.process = nil
        status = .stopped

        guard process.isRunning else {
            self.process = nil
            return
        }

        let proc = process
        self.process = nil

        Task.detached {
            proc.interrupt()

            for _ in 0..<50 {
                if !proc.isRunning { return }
                try? await Task.sleep(nanoseconds: 100_000_000)
            }

            if proc.isRunning {
                proc.terminate()
            }

            for _ in 0..<30 {
                if !proc.isRunning { return }
                try? await Task.sleep(nanoseconds: 100_000_000)
            }

            if proc.isRunning {
                kill(proc.processIdentifier, SIGKILL)
            }
        }
    }

    func restart() {
bench/bench.toml (new file, 7 lines)

@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
    "single-m3-ultra.toml",
]
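The comments in bench.toml describe a two-level layout: the manifest lists suite files, and each suite carries shared constraints, a topology, per-file defaults, and individual [[benchmark]] tables. As a rough sketch of how a runner could resolve that layout (the function name and merge details are assumptions based on the comments above, not the actual bench.py implementation; per the suite file, benchmark-level args win over suite defaults):

```python
# Hypothetical manifest loader; only the TOML keys (include, constraints,
# defaults, benchmark, extra_constraints) come from the files in this diff.
import tomllib
from pathlib import Path
from typing import Any


def load_suites(manifest_path: Path) -> list[dict[str, Any]]:
    manifest = tomllib.loads(manifest_path.read_text())
    resolved: list[dict[str, Any]] = []
    for suite_file in manifest.get("include", []):
        suite = tomllib.loads((manifest_path.parent / suite_file).read_text())
        defaults = suite.get("defaults", {})
        shared = suite.get("constraints", [])
        for bench in suite.get("benchmark", []):
            merged = {**defaults, **bench}  # benchmark-level args win
            merged["constraints"] = shared + bench.get("extra_constraints", [])
            resolved.append(merged)
    return resolved
```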
@@ -288,6 +288,151 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
@@ -535,6 +680,11 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -569,13 +719,16 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and args.settle_timeout > 0:
|
||||
if not selected and settle_deadline:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + args.settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
@@ -607,6 +760,16 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
bench/single-m3-ultra.toml (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
# Single-node M3 Ultra benchmarks
|
||||
#
|
||||
# Shared constraints applied to ALL benchmarks in this file.
|
||||
constraints = [
|
||||
"All(MacOsBuild(=25D125))",
|
||||
"Hosts(=1)",
|
||||
"All(Chip(m3_ultra))",
|
||||
"All(GpuCores(=80))",
|
||||
]
|
||||
|
||||
[topology]
|
||||
type = "none"
|
||||
|
||||
# Default args merged into each benchmark's args (benchmark-level args win).
|
||||
[defaults]
|
||||
pp = [512, 2048, 8192, 16384]
|
||||
tg = 128
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-0.6B-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-0.6B-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.5-Air-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.5-Air-bf16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/MiniMax-M2.1-3bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/MiniMax-M2.1-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-6bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-8Bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-6bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
@@ -265,6 +265,7 @@
|
||||
|
||||
function handleSubmit() {
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
|
||||
if (isEditOnlyWithoutImage) return;
|
||||
|
||||
const content = message.trim();
|
||||
const files = [...uploadedFiles];
|
||||
@@ -289,7 +290,11 @@
|
||||
if (imageFile.preview) {
|
||||
editImage(content, imageFile.preview);
|
||||
}
|
||||
} else if (isImageModel() && content) {
|
||||
} else if (
|
||||
currentModel &&
|
||||
modelSupportsTextToImage(currentModel) &&
|
||||
content
|
||||
) {
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
|
||||
@@ -225,6 +225,7 @@
|
||||
}
|
||||
|
||||
function handleDeleteClick(messageId: string) {
|
||||
if (loading) return;
|
||||
deleteConfirmId = messageId;
|
||||
}
|
||||
|
||||
@@ -255,7 +256,7 @@
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message (message.id)}
|
||||
{#each messageList as message, i (message.id)}
|
||||
<div
|
||||
class="group flex {message.role === 'user'
|
||||
? 'justify-end'
|
||||
@@ -317,9 +318,11 @@
|
||||
<!-- Delete confirmation -->
|
||||
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
|
||||
<p class="text-xs text-red-400 mb-3">
|
||||
Delete this message{message.role === "user"
|
||||
? " and all responses after it"
|
||||
: ""}?
|
||||
{#if i === messageList.length - 1}
|
||||
Delete this message?
|
||||
{:else}
|
||||
Delete this message and all messages after it?
|
||||
{/if}
|
||||
</p>
|
||||
<div class="flex gap-2 justify-end">
|
||||
<button
|
||||
@@ -751,8 +754,13 @@
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
onclick={() => handleDeleteClick(message.id)}
|
||||
class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
|
||||
title="Delete message"
|
||||
disabled={loading}
|
||||
class="p-1.5 transition-colors rounded {loading
|
||||
? 'text-exo-light-gray/30 cursor-not-allowed'
|
||||
: 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
|
||||
title={loading
|
||||
? "Cannot delete while generating"
|
||||
: "Delete message"}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
|
||||
File diff suppressed because it is too large.
@@ -132,7 +132,7 @@ markers = [
env = [
    "EXO_TESTS=1"
]
addopts = "-m 'not slow'"
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
filterwarnings = [
    "ignore:builtin type Swig:DeprecationWarning",
]
@@ -14,7 +14,9 @@
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel.
|
||||
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
|
||||
# Copy .pyi stub + py.typed marker so basedpyright can find the types.
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
@@ -22,6 +24,12 @@
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
passthru = prev.exo-pyo3-bindings.passthru or { };
|
||||
postInstall = ''
|
||||
local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings
|
||||
cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/
|
||||
touch $siteDir/py.typed
|
||||
'';
|
||||
};
|
||||
};
|
||||
|
||||
@@ -29,17 +37,32 @@
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
} // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
# Use our pure Nix-built MLX with Metal support (macOS only)
|
||||
mlx = self'.packages.mlx;
|
||||
};
|
||||
|
||||
# Additional overlay for Linux-specific fixes (type checking env).
|
||||
# Native wheels have shared lib dependencies we don't need at type-check time.
|
||||
linuxOverlay = final: prev:
|
||||
let
|
||||
ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
|
||||
nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
|
||||
in
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
);
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
@@ -48,6 +71,7 @@
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
@@ -118,6 +142,21 @@
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
|
||||
# Hermetic basedpyright type checking
|
||||
typecheck = pkgs.runCommand "typecheck"
|
||||
{
|
||||
nativeBuildInputs = [
|
||||
testVenv
|
||||
pkgs.basedpyright
|
||||
];
|
||||
}
|
||||
''
|
||||
cd ${inputs.self}
|
||||
export HOME=$TMPDIR
|
||||
basedpyright --pythonpath ${testVenv}/bin/python
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ from exo.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
@@ -63,6 +64,9 @@ class DownloadCoordinator:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
return str(EXO_MODELS_DIR / model_id.normalize())
|
||||
|
||||
async def _download_progress_callback(
|
||||
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
@@ -74,6 +78,7 @@ class DownloadCoordinator:
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
@@ -93,6 +98,7 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
@@ -170,7 +176,11 @@ class DownloadCoordinator:
|
||||
return
|
||||
|
||||
# Emit pending status
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
@@ -184,6 +194,7 @@ class DownloadCoordinator:
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
@@ -206,6 +217,7 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
@@ -219,6 +231,7 @@ class DownloadCoordinator:
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=str(e),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(
|
||||
@@ -253,6 +266,7 @@ class DownloadCoordinator:
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
@@ -295,11 +309,18 @@ class DownloadCoordinator:
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
@@ -308,6 +329,9 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -136,6 +136,8 @@ class Node:

    async def run(self):
        async with self._tg as tg:
            signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
            signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
            tg.start_soon(self.router.run)
            tg.start_soon(self.election.run)
            if self.download_coordinator:
@@ -147,8 +149,6 @@ class Node:
            if self.api:
                tg.start_soon(self.api.run)
            tg.start_soon(self._elect_loop)
            signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
            signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())

    def shutdown(self):
        # if this is our second call to shutdown, just sys.exit
@@ -17,6 +17,7 @@ from exo.shared.types.api import (
|
||||
LogprobsContentItem,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -125,6 +126,8 @@ async def generate_chat_stream(
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
@@ -138,6 +141,8 @@ async def generate_chat_stream(
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
@@ -161,12 +166,15 @@ async def generate_chat_stream(
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
if chunk.finish_reason is not None:
|
||||
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -176,7 +184,9 @@ async def generate_chat_stream(
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ChatCompletionResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
@@ -184,6 +194,7 @@ async def collect_chat_response(
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -193,6 +204,8 @@ async def collect_chat_response(
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
@@ -223,7 +236,7 @@ async def collect_chat_response(
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
yield ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
@@ -241,4 +254,6 @@ async def collect_chat_response(
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
usage=last_usage,
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
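The inline comment above gives the reason collect_chat_response now yields a serialized JSON string: FastAPI's StreamingResponse can consume the generator and propagate client cancellation, which a plain returned response model did not allow. A minimal sketch of the consuming side under that pattern (the endpoint and the chunk-stream helper are illustrative placeholders; the StreamingResponse wiring mirrors the api.py change further down in this diff):

```python
# Hypothetical wiring; make_chunk_stream and the request shape are placeholders.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/v1/chat/completions")
async def chat_completions(payload: dict) -> StreamingResponse:
    chunk_stream = make_chunk_stream(payload)  # placeholder for the real stream
    # Wrapping the generator lets FastAPI cancel it when the client disconnects.
    return StreamingResponse(
        collect_chat_response("cmd_example", chunk_stream),  # yields one JSON string
        media_type="application/json",
    )
```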
@@ -4,7 +4,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.api import FinishReason, Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
@@ -161,12 +161,14 @@ async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -174,6 +176,8 @@ async def collect_claude_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
@@ -183,12 +187,10 @@ async def collect_claude_response(
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -208,11 +210,11 @@ async def collect_claude_response(
|
||||
if not content:
|
||||
content.append(ClaudeTextBlock(text=""))
|
||||
|
||||
# Use actual usage data from stats if available
|
||||
input_tokens = last_stats.prompt_tokens if last_stats else 0
|
||||
output_tokens = last_stats.generation_tokens if last_stats else 0
|
||||
# Use actual usage data if available
|
||||
input_tokens = last_usage.prompt_tokens if last_usage else 0
|
||||
output_tokens = last_usage.completion_tokens if last_usage else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
yield ClaudeMessagesResponse(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=content,
|
||||
@@ -221,7 +223,8 @@ async def collect_claude_response(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
@@ -249,7 +252,7 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -257,8 +260,9 @@ async def generate_claude_stream(
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
|
||||
# Emit tool_use content blocks
|
||||
@@ -290,7 +294,6 @@ async def generate_claude_stream(
|
||||
continue
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
@@ -302,9 +305,9 @@ async def generate_claude_stream(
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
# Use actual token count from usage if available
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
@@ -121,13 +122,15 @@ async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ResponsesResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -135,32 +138,32 @@ async def collect_responses_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=f"fc_{tool.id}",
|
||||
call_id=f"call_{tool.id}",
|
||||
id=tool.id,
|
||||
call_id=tool.id,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create usage from stats if available
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
if last_usage is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
@@ -172,14 +175,15 @@ async def collect_responses_response(
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
return ResponsesResponse(
|
||||
yield ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
@@ -235,15 +239,16 @@ async def generate_responses_stream(
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{tool.id}"
|
||||
call_id = f"call_{tool.id}"
|
||||
@@ -302,7 +307,6 @@ async def generate_responses_stream(
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
@@ -346,13 +350,13 @@ async def generate_responses_stream(
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
if last_usage is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
|
||||
@@ -125,6 +125,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -142,7 +143,12 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
@@ -309,6 +315,7 @@ class API:
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
mlx_device=payload.mlx_device,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
@@ -349,6 +356,7 @@ class API:
|
||||
sharding: Sharding = Sharding.Pipeline,
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
mlx_device: MlxDevice = MlxDevice.Auto,
|
||||
) -> Instance:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
|
||||
@@ -359,6 +367,7 @@ class API:
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
mlx_device=mlx_device,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
@@ -540,16 +549,14 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._text_generation_queues:
|
||||
del self._text_generation_queues[command_id]
|
||||
|
||||
@@ -644,11 +651,14 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionRequest
|
||||
@@ -664,8 +674,7 @@ class API:
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_text_generation_with_stats(command.command_id)
|
||||
return response
|
||||
return await self._collect_text_generation_with_stats(command.command_id)
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
"""Validate a text model exists and return the resolved model ID.
|
||||
@@ -883,6 +892,11 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -964,6 +978,11 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -1221,12 +1240,15 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_claude_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_claude_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def openai_responses(
|
||||
self, payload: ResponsesRequest
|
||||
@@ -1254,11 +1276,15 @@ class API:
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_responses_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_responses_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
|
||||
@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
@@ -39,6 +40,7 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
@@ -279,7 +281,7 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
@@ -299,7 +301,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -309,7 +311,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -319,6 +321,18 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -327,10 +341,9 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
# rate limit to 1000 at a time
|
||||
|
||||
@@ -22,9 +22,15 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
@@ -153,6 +159,7 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
jaccl_devices=mlx_jaccl_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
mlx_device=command.mlx_device,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
@@ -167,6 +174,7 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
mlx_device=command.mlx_device,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
@@ -186,6 +194,7 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -201,6 +210,18 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
|
||||
@@ -4,7 +4,11 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
|
||||
from exo.master.adapters.claude import (
|
||||
ClaudeMessagesResponse,
|
||||
collect_claude_response,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
@@ -17,6 +21,18 @@ async def _chunks_to_stream(
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _collect_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Helper to consume the async generator and parse the JSON response."""
|
||||
parts: list[str] = []
|
||||
async for part in collect_claude_response(command_id, model, chunk_stream):
|
||||
parts.append(part)
|
||||
return ClaudeMessagesResponse.model_validate_json("".join(parts))
|
||||
|
||||
|
||||
MODEL = ModelId("test-model")
|
||||
COMMAND_ID = CommandId("cmd_test123")
|
||||
|
||||
@@ -47,7 +63,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -77,7 +93,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -102,7 +118,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -116,7 +132,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
|
||||
async def test_no_content_produces_empty_text_block(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
|
||||
response = await collect_claude_response(
|
||||
response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
assert len(response.content) == 1
|
||||
|
||||
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
|
||||
target_instances = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 0
|
||||
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
|
||||
@@ -218,11 +218,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -263,7 +258,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

@@ -3,13 +3,17 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel

@@ -227,13 +231,7 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
def use_default(cls, v: object):
if not v or not isinstance(v, (Sharding, InstanceMeta)):
raise PydanticUseDefault()
return v
mlx_device: MlxDevice = MlxDevice.Auto


class CreateInstanceParams(BaseModel):

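The validator removed above relied on pydantic v2's escape hatch for bad client input: raising PydanticUseDefault from a mode="plain" field validator makes the field fall back to its declared default instead of failing validation. A minimal, self-contained sketch of that pattern (toy model, not the exo types):

    from enum import Enum

    from pydantic import BaseModel, field_validator
    from pydantic_core import PydanticUseDefault

    class Sharding(str, Enum):
        Pipeline = "Pipeline"
        Tensor = "Tensor"

    class Params(BaseModel):
        sharding: Sharding = Sharding.Pipeline

        @field_validator("sharding", mode="plain")
        @classmethod
        def use_default(cls, v: object) -> Sharding:
            # Falsy or wrongly-typed input falls back to the field default.
            if not v or not isinstance(v, Sharding):
                raise PydanticUseDefault()
            return v

    # None would normally be a validation error; here it resolves to the default.
    assert Params.model_validate({"sharding": None}).sharding is Sharding.Pipeline

The new code keeps plain defaults only, which presumably means malformed values now surface as ordinary validation errors rather than being coerced to defaults.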
@@ -8,7 +8,12 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -38,6 +43,7 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
mlx_device: MlxDevice = MlxDevice.Auto


class CreateInstance(BaseCommand):
@@ -48,6 +54,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId


class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId


class TaskFinished(BaseCommand):
finished_command_id: CommandId

@@ -89,6 +99,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

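Since the Command union is a set of tagged pydantic models, the new variants can be handled with structural pattern matching on the receiving side; a hedged sketch (the handler itself is illustrative and not part of this diff):

    def handle_command(command: Command) -> None:
        # Illustrative dispatch only; real handling lives in the master's event loop.
        match command:
            case TaskCancelled(cancelled_command_id=cancelled):
                print(f"cancel requested for {cancelled}")
            case TaskFinished(finished_command_id=finished):
                print(f"command {finished} finished")
            case PlaceInstance(mlx_device=device):
                print(f"placing instance on device {device}")
            case _:
                pass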
@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"


class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)


class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId


class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
| LoadModel
| StartWarmup
| TextGeneration
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

@@ -26,6 +26,7 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""


class DownloadPending(BaseDownloadProgress):

@@ -16,9 +16,16 @@ class InstanceMeta(str, Enum):
MlxJaccl = "MlxJaccl"


class MlxDevice(str, Enum):
Auto = "Auto"
Cpu = "Cpu"
Gpu = "Gpu"


class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
mlx_device: MlxDevice = MlxDevice.Auto

def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

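Because mlx_device carries a default, instance records serialized before this change should still validate. A toy round-trip that checks that assumption with a stand-in model rather than the real BaseInstance:

    from enum import Enum

    from pydantic import BaseModel

    class MlxDevice(str, Enum):
        Auto = "Auto"
        Cpu = "Cpu"
        Gpu = "Gpu"

    class InstanceStub(BaseModel):
        # Stand-in for BaseInstance; only the new field is relevant here.
        instance_id: str
        mlx_device: MlxDevice = MlxDevice.Auto

    old_payload = {"instance_id": "inst_1"}  # written before mlx_device existed
    assert InstanceStub.model_validate(old_payload).mlx_device is MlxDevice.Auto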
@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
stats: GenerationStats | None = None


class FinishedResponse(BaseRunnerResponse):

@@ -1,5 +1,7 @@
import sys


def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:

"""

print(banner)
print(banner, file=sys.stderr)

@@ -125,7 +125,9 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)

async def send_async(self, item: T) -> None:
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
await to_thread.run_sync(
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
)

def close(self) -> None:
if not self._state.closed.is_set():

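The added abandon_on_cancel=True matters when the peer never drains the queue: with anyio's default, cancelling send_async would still block until the worker thread's put() returned. A small standalone illustration of the flag, assuming anyio >= 4 (the blocking sleep stands in for a put() that never completes, and the abandoned thread may keep running briefly after the program prints):

    import time

    import anyio
    from anyio import to_thread

    def blocking_call() -> None:
        time.sleep(2)  # stands in for a queue.put() that never returns

    async def main() -> None:
        with anyio.move_on_after(0.1):
            # With abandon_on_cancel=True the cancellation takes effect
            # immediately; the worker thread is left to finish on its own.
            await to_thread.run_sync(blocking_call, abandon_on_cancel=True)
        print("cancelled without waiting for the thread")

    anyio.run(main)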
@@ -15,8 +15,3 @@ DEFAULT_TOP_LOGPROBS: int = 5

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)

@@ -1,9 +1,6 @@
import time
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Generator, cast, get_args

if TYPE_CHECKING:
from exo.worker.engines.mlx.mtp.module import MTPModule
from typing import Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm.generate import stream_generate
@@ -41,8 +38,6 @@ from exo.worker.engines.mlx.constants import (
KV_BITS,
KV_GROUP_SIZE,
MAX_TOKENS,
MTP_ENABLED,
MTP_NUM_DRAFT_TOKENS,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -197,11 +192,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos


def _has_mtp_module(model: Model) -> bool:
"""Check if the model has an attached MTP module."""
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]


def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
@@ -307,24 +297,6 @@ def mlx_generate(
top_k=task.top_k if task.top_k is not None else 0,
)

max_tokens = task.max_output_tokens or MAX_TOKENS

# Check if we should use MTP speculative decoding
use_mtp = MTP_ENABLED and _has_mtp_module(model)

if use_mtp:
logger.info("Using MTP speculative decoding")
yield from _mlx_generate_with_mtp(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
return

# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
@@ -334,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)

mx_barrier(group)
logger.info("Ready to prefill")
logger.info("Starting prefill")

# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
@@ -349,6 +321,7 @@ def mlx_generate(
# stream_generate starts from the last token
last_token = prompt_tokens[-2:]

max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
@@ -420,10 +393,11 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)

total_prompt_tokens = len(all_prompt_tokens)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=total_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
total_tokens=total_prompt_tokens + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
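A plausible reading of this change: after a prefix-cache hit, out.prompt_tokens only reflects tokens actually processed in this call, so reporting it directly would under-count the prompt the client sent. A worked example with made-up numbers under that assumption:

    # Hypothetical 1,000-token prompt, 800 of which hit the prefix cache.
    all_prompt_tokens = list(range(1000))
    prefix_hit_length = 800
    completion_tokens = 50

    total_prompt_tokens = len(all_prompt_tokens)  # 1000: what the client sent
    usage = {
        "prompt_tokens": total_prompt_tokens,                           # 1000
        "completion_tokens": completion_tokens,                         # 50
        "total_tokens": total_prompt_tokens + completion_tokens,        # 1050
        "prompt_tokens_details": {"cached_tokens": prefix_hit_length},  # 800
    }
    assert usage["total_tokens"] == 1050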
@@ -497,66 +471,3 @@ def mlx_generate(
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]

# TODO: Do we want an mx_barrier?


def _mlx_generate_with_mtp(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
|
||||
prompt_cache: KVCacheType,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""MTP speculative decoding generation path.
|
||||
|
||||
Uses the model's attached MTP module for speculative decoding,
|
||||
which can provide 1.5-2x speedup with ~81% acceptance rate.
|
||||
"""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
|
||||
|
||||
mtp_module = cast("MTPModule", model.mtp_module)
|
||||
|
||||
for out in mtp_speculative_generate(
|
||||
model=model,
|
||||
mtp_module=mtp_module,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=cast(list[Any], prompt_cache),
|
||||
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(f"{out.text} (from_draft={out.from_draft})")
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
@@ -1,6 +0,0 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""

from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate

__all__ = ["MTPModule", "mtp_speculative_generate"]
@@ -1,207 +0,0 @@
"""MTP Module for DeepSeek V3 Multi-Token Prediction.

The MTP architecture predicts one additional token ahead using:
1. hnorm - RMSNorm for hidden state normalization
2. enorm - RMSNorm for embedding normalization
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
4. transformer_block - Single decoder layer (attention + MLP)
5. Shared embedding/lm_head from main model

Forward pass:
h_norm = hnorm(hidden_state)
e_norm = enorm(embed(token))
projected = eh_proj(concat([h_norm, e_norm]))
new_hidden = transformer_block(projected)
logits = lm_head(output_norm(new_hidden))
"""

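The forward-pass recipe in that docstring can be sanity-checked shape-by-shape with plain mlx ops; a toy sketch (random weights, hidden_size=8 and vocab_size=16 are made-up numbers, and the single decoder layer is elided):

    import mlx.core as mx
    import mlx.nn as nn

    hidden_size, vocab_size = 8, 16
    embed = nn.Embedding(vocab_size, hidden_size)
    hnorm = nn.RMSNorm(hidden_size)
    enorm = nn.RMSNorm(hidden_size)
    eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    hidden_state = mx.random.normal((1, 1, hidden_size))  # from the main model
    draft_token = mx.array([[3]])

    h_norm = hnorm(hidden_state)                                     # (1, 1, 8)
    e_norm = enorm(embed(draft_token))                               # (1, 1, 8)
    projected = eh_proj(mx.concatenate([h_norm, e_norm], axis=-1))   # (1, 1, 8)
    # the real module runs `projected` through one DeepseekV3-style decoder layer
    logits = lm_head(projected)                                      # (1, 1, 16)
    assert logits.shape == (1, 1, vocab_size)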
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
DeepseekV3Attention,
|
||||
DeepseekV3MLP,
|
||||
ModelArgs,
|
||||
)
|
||||
|
||||
MTP_LAYER_INDEX = 61
|
||||
|
||||
|
||||
class MTPModule(nn.Module):
|
||||
"""Multi-Token Prediction module for DeepSeek V3.
|
||||
|
||||
This module is initialized from the layer 61 weights that are normally
|
||||
discarded during model loading. It enables speculative decoding by
|
||||
predicting one token ahead using the hidden state from the main model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
shared_embedding: nn.Embedding,
|
||||
shared_lm_head: nn.Linear,
|
||||
output_norm: nn.RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# MTP-specific normalization layers
|
||||
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Projection: concatenated [hidden, embedding] -> hidden_size
|
||||
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
# Single transformer block for MTP
|
||||
# Use a dense MLP since this is just a single layer
|
||||
self.transformer_block = MTPTransformerBlock(config)
|
||||
|
||||
# Share embedding and lm_head with main model
|
||||
self._shared_embedding = shared_embedding
|
||||
self._shared_lm_head = shared_lm_head
|
||||
self._output_norm = output_norm
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
cache: KVCache | None = None,
|
||||
mask: mx.array | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass for MTP.
|
||||
|
||||
Args:
|
||||
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
|
||||
draft_token: Token to embed and combine with hidden state [batch, seq_len]
|
||||
cache: Optional KV cache for the MTP transformer block
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
tuple of (logits, new_hidden_state)
|
||||
"""
|
||||
# Get embedding of draft token
|
||||
embedding = self._shared_embedding(draft_token)
|
||||
|
||||
# Normalize hidden state and embedding
|
||||
h_norm = self.hnorm(hidden_state)
|
||||
e_norm = self.enorm(embedding)
|
||||
|
||||
# Project concatenated representation
|
||||
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
|
||||
projected = self.eh_proj(concatenated)
|
||||
|
||||
# Pass through single transformer block
|
||||
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
|
||||
|
||||
# Apply output norm and get logits
|
||||
normed_hidden = self._output_norm(new_hidden)
|
||||
logits = self._shared_lm_head(normed_hidden)
|
||||
|
||||
return logits, new_hidden
|
||||
|
||||
|
||||
class MTPTransformerBlock(nn.Module):
|
||||
"""Single transformer block for MTP.
|
||||
|
||||
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
|
||||
instead of MoE since this is just for the single MTP layer.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
# MTP uses dense MLP, not MoE
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
"""Forward pass with residual connections."""
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
|
||||
"""Extract MTP-specific weights from layer 61.
|
||||
|
||||
The MTP layer has these weight patterns:
|
||||
- model.layers.61.enorm.weight -> MTP embedding normalization
|
||||
- model.layers.61.hnorm.weight -> MTP hidden normalization
|
||||
- model.layers.61.eh_proj.weight -> MTP projection layer
|
||||
- model.layers.61.self_attn.* -> MTP attention
|
||||
- model.layers.61.input_layernorm.* -> MTP layer norms
|
||||
- model.layers.61.post_attention_layernorm.*
|
||||
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
|
||||
|
||||
Args:
|
||||
weights: Full model weights dict
|
||||
|
||||
Returns:
|
||||
Dict of MTP-specific weights with keys renamed for MTPModule
|
||||
"""
|
||||
mtp_weights: dict[str, mx.array] = {}
|
||||
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mtp_prefix):
|
||||
# Remove the layer prefix to get relative path
|
||||
new_key = key[len(mtp_prefix) :]
|
||||
mtp_weights[new_key] = value
|
||||
|
||||
return mtp_weights
|
||||
|
||||
|
||||
def load_mtp_weights_into_module(
|
||||
mtp_module: MTPModule,
|
||||
mtp_weights: dict[str, mx.array],
|
||||
) -> None:
|
||||
"""Load extracted MTP weights into the MTPModule.
|
||||
|
||||
Args:
|
||||
mtp_module: The MTPModule instance to load weights into
|
||||
mtp_weights: Extracted MTP weights from extract_mtp_weights()
|
||||
"""
|
||||
# Map weight names to module attributes
|
||||
weight_mapping: dict[str, str] = {
|
||||
"enorm.weight": "enorm.weight",
|
||||
"hnorm.weight": "hnorm.weight",
|
||||
"eh_proj.weight": "eh_proj.weight",
|
||||
}
|
||||
|
||||
# Load direct mappings
|
||||
for src_name, dst_name in weight_mapping.items():
|
||||
if src_name in mtp_weights:
|
||||
parts = dst_name.split(".")
|
||||
obj: Any = mtp_module
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part) # pyright: ignore[reportAny]
|
||||
setattr(obj, parts[-1], mtp_weights[src_name]) # pyright: ignore[reportAny]
|
||||
|
||||
# Load transformer block weights (self_attn, mlp, layer norms)
|
||||
transformer_prefixes = [
|
||||
"self_attn",
|
||||
"mlp",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
]
|
||||
|
||||
for prefix in transformer_prefixes:
|
||||
for key, value in mtp_weights.items():
|
||||
if key.startswith(prefix):
|
||||
# Navigate to the correct attribute
|
||||
parts = key.split(".")
|
||||
obj = mtp_module.transformer_block
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part) # pyright: ignore[reportAny]
|
||||
setattr(obj, parts[-1], value) # pyright: ignore[reportAny]
|
||||
@@ -1,483 +0,0 @@
|
||||
"""MTP Speculative Decoding for DeepSeek V3.
|
||||
|
||||
This module implements speculative decoding using the Multi-Token Prediction (MTP)
|
||||
layer from DeepSeek V3. The key difference from standard speculative decoding is
|
||||
that MTP requires hidden states from the main model, not just token predictions.
|
||||
|
||||
Based on vLLM/SGLang research:
|
||||
- 81-82% acceptance rate with k=1
|
||||
- 1.5-2x speedup at low QPS
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models import cache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Generation stream for async operations
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTPGenerationResponse:
|
||||
"""Response from MTP speculative generation.
|
||||
|
||||
Attributes:
|
||||
text: The next segment of decoded text.
|
||||
token: The next token.
|
||||
logprobs: A vector of log probabilities.
|
||||
from_draft: Whether the token was generated by the MTP draft module.
|
||||
prompt_tokens: The number of tokens in the prompt.
|
||||
prompt_tps: The prompt processing tokens-per-second.
|
||||
generation_tokens: The number of generated tokens.
|
||||
generation_tps: The tokens-per-second for generation.
|
||||
peak_memory: The peak memory used so far in GB.
|
||||
finish_reason: The reason the response is being sent: "length", "stop" or None.
|
||||
"""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
from_draft: bool
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
"""Quantize KV cache entries if needed."""
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache): # pyright: ignore[reportAny]
|
||||
if (
|
||||
hasattr(c, "to_quantized") # pyright: ignore[reportAny]
|
||||
and hasattr(c, "offset") # pyright: ignore[reportAny]
|
||||
and c.offset >= quantized_kv_start # pyright: ignore[reportAny]
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
class ModelWithHiddenStates(nn.Module):
|
||||
"""Wrapper to extract hidden states before lm_head.
|
||||
|
||||
This wrapper allows capturing the hidden states from the transformer
|
||||
layers before the final lm_head projection, which is needed for MTP.
|
||||
"""
|
||||
|
||||
def __init__(self, base_model: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self._base = base_model
|
||||
|
||||
def forward_with_hidden(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass that returns both logits and hidden states.
|
||||
|
||||
Args:
|
||||
inputs: Input token ids
|
||||
model_cache: KV cache
|
||||
|
||||
Returns:
|
||||
Tuple of (logits, hidden_states)
|
||||
"""
|
||||
# Call the inner model (transformer layers + norm)
|
||||
hidden = cast(mx.array, self._base.model(inputs, model_cache)) # pyright: ignore[reportUnknownMemberType]
|
||||
# Get logits from lm_head
|
||||
logits = cast(mx.array, self._base.lm_head(hidden)) # pyright: ignore[reportUnknownMemberType]
|
||||
return logits, hidden
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> mx.array:
|
||||
"""Standard forward pass returning only logits."""
|
||||
return cast(mx.array, self._base(inputs, cache=model_cache))
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
"""Access layers for cache creation."""
|
||||
return cast(list[nn.Module], self._base.layers)
|
||||
|
||||
|
||||
def mtp_speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
*,
|
||||
num_draft_tokens: int = 1,
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
mtp_cache: KVCache | None = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: int | None = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[tuple[int, mx.array, bool], None, None]:
|
||||
"""MTP speculative decoding generator.
|
||||
|
||||
Unlike standard speculative decoding where the draft model only needs tokens,
|
||||
MTP requires the hidden states from the main model. This generator:
|
||||
|
||||
1. Runs the main model to get logits AND hidden states
|
||||
2. Uses MTP module with hidden state + sampled token to predict next token
|
||||
3. Verifies MTP predictions with the main model
|
||||
4. Accepts/rejects based on matching
|
||||
|
||||
Args:
|
||||
prompt: The input prompt as token ids
|
||||
model: The main model (must support return_hidden=True)
|
||||
mtp_module: The MTP module for draft prediction
|
||||
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
sampler: Optional sampler function for token selection
|
||||
logits_processors: Optional list of logits processors
|
||||
prompt_cache: KV cache for the main model
|
||||
mtp_cache: KV cache for the MTP module
|
||||
prefill_step_size: Step size for prompt processing
|
||||
kv_bits: Bits for KV cache quantization
|
||||
kv_group_size: Group size for KV cache quantization
|
||||
quantized_kv_start: Step to begin cache quantization
|
||||
|
||||
Yields:
|
||||
Tuple of (token, logprobs, from_draft)
|
||||
"""
|
||||
y = prompt.astype(mx.uint32)
|
||||
prev_tokens: mx.array | None = None
|
||||
|
||||
# Wrap model to get hidden states
|
||||
wrapped_model = (
|
||||
model
|
||||
if isinstance(model, ModelWithHiddenStates)
|
||||
else ModelWithHiddenStates(model)
|
||||
)
|
||||
|
||||
# Create caches if needed
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
if mtp_cache is None:
|
||||
mtp_cache = KVCache()
|
||||
|
||||
final_sampler: Callable[[mx.array], mx.array] = (
|
||||
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
|
||||
)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _process_and_sample(
|
||||
tokens: mx.array | None,
|
||||
logits: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Process logits and sample tokens."""
|
||||
nonlocal logits_processors
|
||||
processed_logits = logits
|
||||
if logits_processors:
|
||||
for processor in logits_processors:
|
||||
processed_logits = processor(
|
||||
tokens if tokens is not None else mx.array([]), processed_logits
|
||||
)
|
||||
|
||||
logprobs = processed_logits - mx.logsumexp(
|
||||
processed_logits, axis=-1, keepdims=True
|
||||
)
|
||||
sampled = final_sampler(logprobs)
|
||||
return sampled, logprobs
|
||||
|
||||
def _main_model_step_with_hidden(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array, mx.array]:
|
||||
"""Run main model step with hidden state return."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits, hidden = wrapped_model.forward_with_hidden(
|
||||
input_y[None], prompt_cache
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
|
||||
|
||||
def _mtp_draft(
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Generate draft token using MTP module."""
|
||||
with mx.stream(generation_stream):
|
||||
logits, new_hidden = mtp_module(
|
||||
hidden_state,
|
||||
draft_token,
|
||||
cache=mtp_cache,
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
sampled, _ = _process_and_sample(None, logits)
|
||||
return sampled, new_hidden
|
||||
|
||||
def _prefill(input_y: mx.array) -> mx.array:
|
||||
"""Prefill the prompt cache."""
|
||||
result_y = input_y
|
||||
while result_y.size > prefill_step_size:
|
||||
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache]) # pyright: ignore[reportAny]
|
||||
result_y = result_y[prefill_step_size:]
|
||||
mx.clear_cache()
|
||||
return result_y
|
||||
|
||||
def _rewind_cache(num_draft: int, num_accept: int) -> None:
|
||||
"""Rewind caches after rejection."""
|
||||
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
|
||||
|
||||
# Prefill phase
|
||||
with mx.stream(generation_stream):
|
||||
y = _prefill(y)
|
||||
|
||||
ntoks = 0
|
||||
num_draft = 0
|
||||
n_accepted = 0
|
||||
last_hidden: mx.array | None = None
|
||||
|
||||
try:
|
||||
# Initial step to get first token and hidden state
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
while ntoks < max_tokens:
|
||||
# Draft phase: use MTP to predict next token
|
||||
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
|
||||
|
||||
if num_draft > 0 and last_hidden is not None: # pyright: ignore[reportUnnecessaryComparison]
|
||||
# Use MTP to draft
|
||||
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
|
||||
mx.eval(draft_token, draft_hidden)
|
||||
|
||||
# Verify with main model
|
||||
# Feed the drafted token to main model
|
||||
verify_input = mx.concatenate([y, draft_token.flatten()])
|
||||
verify_sampled, verify_logprobs, new_hidden = (
|
||||
_main_model_step_with_hidden(verify_input)
|
||||
)
|
||||
mx.eval(verify_sampled, verify_logprobs, new_hidden)
|
||||
|
||||
# Check if draft matches verification
|
||||
draft_token_val = int(draft_token.item())
|
||||
verify_token_val = (
|
||||
int(verify_sampled[0].item())
|
||||
if verify_sampled.shape[0] > 1
|
||||
else int(verify_sampled.item())
|
||||
)
|
||||
|
||||
# Yield the current token (not from draft)
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
if draft_token_val == verify_token_val:
|
||||
# Draft accepted
|
||||
n_accepted += 1
|
||||
ntoks += 1
|
||||
draft_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
yield draft_token_val, draft_logprobs, True
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
# Continue with the token after the draft
|
||||
y = (
|
||||
verify_sampled[-1:]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[-1]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden
|
||||
else:
|
||||
# Draft rejected - rewind and use verified token
|
||||
_rewind_cache(1, 0)
|
||||
y = (
|
||||
verify_sampled[:1]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden[:, :1, :]
|
||||
else:
|
||||
# No drafting, just do normal generation
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
if ntoks % 256 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
finally:
|
||||
_rewind_cache(num_draft, n_accepted)
|
||||
|
||||
|
||||
def mtp_speculative_generate(
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str | mx.array | list[int],
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
num_draft_tokens: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int | None = None,
|
||||
) -> Generator[MTPGenerationResponse, None, None]:
|
||||
"""High-level MTP speculative generation with text output.
|
||||
|
||||
Args:
|
||||
model: The main model
|
||||
mtp_module: The MTP module for draft prediction
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
prompt: Input prompt (string, array, or token list)
|
||||
max_tokens: Maximum tokens to generate
|
||||
sampler: Optional sampler function
|
||||
logits_processors: Optional logits processors
|
||||
prompt_cache: Optional KV cache
|
||||
num_draft_tokens: Number of draft tokens
|
||||
prefill_step_size: Prefill step size
|
||||
kv_group_size: KV group size
|
||||
kv_bits: KV bits
|
||||
|
||||
Yields:
|
||||
MTPGenerationResponse objects with text and metadata
|
||||
"""
|
||||
if not isinstance(prompt, mx.array):
|
||||
if isinstance(prompt, str):
|
||||
bos_token: str | None = getattr(tokenizer, "bos_token", None)
|
||||
add_special_tokens = bos_token is None or not prompt.startswith(
|
||||
str(bos_token)
|
||||
)
|
||||
encoded: list[int] = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
prompt = mx.array(encoded)
|
||||
else:
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer: Any = tokenizer.detokenizer
|
||||
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
|
||||
|
||||
token_generator = mtp_speculative_generate_step(
|
||||
prompt,
|
||||
model,
|
||||
mtp_module,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
prompt_tps = 0.0
|
||||
token = 0
|
||||
logprobs: mx.array = mx.array([0.0])
|
||||
from_draft = False
|
||||
n = 0
|
||||
|
||||
for n, (token, logprobs, from_draft) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = float(prompt.size) / prompt_time
|
||||
tic = time.perf_counter()
|
||||
|
||||
if token in eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token) # pyright: ignore[reportAny]
|
||||
if (n + 1) == max_tokens:
|
||||
break
|
||||
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment), # pyright: ignore[reportAny]
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
detokenizer.finalize() # pyright: ignore[reportAny]
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment), # pyright: ignore[reportAny]
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="stop" if token in eos_token_ids else "length",
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for MTP module."""
|
||||
@@ -1,426 +0,0 @@
|
||||
"""Unit tests for MTP module components."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTP_LAYER_INDEX,
|
||||
MTPModule,
|
||||
MTPTransformerBlock,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
|
||||
class MockModelArgs:
|
||||
"""Mock ModelArgs for testing without importing deepseek_v3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
vocab_size: int = 1000,
|
||||
q_lora_rank: int | None = None,
|
||||
kv_lora_rank: int = 64,
|
||||
qk_rope_head_dim: int = 16,
|
||||
v_head_dim: int = 32,
|
||||
qk_nope_head_dim: int = 32,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
attention_bias: bool = False,
|
||||
max_position_embeddings: int = 2048,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
|
||||
class TestExtractMTPWeights:
|
||||
"""Tests for extract_mtp_weights function."""
|
||||
|
||||
def test_extracts_layer_61_weights(self) -> None:
|
||||
"""Should extract only layer 61 weights."""
|
||||
weights = {
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.61.enorm.weight": mx.ones((10,)),
|
||||
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
|
||||
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
|
||||
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.embed_tokens.weight": mx.zeros((100, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 3
|
||||
assert "enorm.weight" in mtp_weights
|
||||
assert "hnorm.weight" in mtp_weights
|
||||
assert "eh_proj.weight" in mtp_weights
|
||||
# Check values are preserved
|
||||
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
|
||||
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
|
||||
|
||||
def test_returns_empty_dict_when_no_layer_61(self) -> None:
|
||||
"""Should return empty dict when layer 61 doesn't exist."""
|
||||
weights = {
|
||||
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 0
|
||||
|
||||
def test_handles_nested_layer_61_weights(self) -> None:
|
||||
"""Should handle nested weight paths like self_attn.q_proj.weight."""
|
||||
weights = {
|
||||
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
|
||||
(10, 10)
|
||||
),
|
||||
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert "self_attn.q_a_proj.weight" in mtp_weights
|
||||
assert "mlp.gate_proj.weight" in mtp_weights
|
||||
|
||||
|
||||
class TestMTPTransformerBlock:
|
||||
"""Tests for MTPTransformerBlock."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64, intermediate_size=128, num_attention_heads=2
|
||||
)
|
||||
|
||||
def test_forward_shape(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should preserve input shape."""
|
||||
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
|
||||
output = block(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
def test_forward_with_mask(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should work with attention mask."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
# Create causal mask
|
||||
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
|
||||
|
||||
output = block(x, mask=mask)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
|
||||
class TestMTPModule:
|
||||
"""Tests for MTPModule."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def shared_components(
|
||||
self, config: MockModelArgs
|
||||
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
return embedding, lm_head, output_norm
|
||||
|
||||
def test_initialization(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should initialize with correct components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
assert mtp.hnorm is not None
|
||||
assert mtp.enorm is not None
|
||||
assert mtp.eh_proj is not None
|
||||
assert mtp.transformer_block is not None
|
||||
|
||||
def test_forward_output_shapes(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""Forward pass should return correct output shapes."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 1
|
||||
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
|
||||
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
|
||||
|
||||
logits, new_hidden = mtp(hidden_state, draft_token)
|
||||
|
||||
assert logits.shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_shares_embedding_and_lm_head(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should use shared embedding and lm_head."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Verify they're the same objects
|
||||
assert mtp._shared_embedding is embedding # pyright: ignore[reportPrivateUsage]
|
||||
assert mtp._shared_lm_head is lm_head # pyright: ignore[reportPrivateUsage]
|
||||
assert mtp._output_norm is output_norm # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class TestLoadMTPWeights:
|
||||
"""Tests for load_mtp_weights_into_module."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
|
||||
"""Should load enorm and hnorm weights."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Create test weights
|
||||
test_enorm = mx.ones((config.hidden_size,)) * 3.0
|
||||
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
|
||||
mtp_weights = {
|
||||
"enorm.weight": test_enorm,
|
||||
"hnorm.weight": test_hnorm,
|
||||
}
|
||||
|
||||
load_mtp_weights_into_module(mtp, mtp_weights)
|
||||
|
||||
assert mx.allclose(mtp.enorm.weight, test_enorm) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
assert mx.allclose(mtp.hnorm.weight, test_hnorm) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
|
||||
class TestSanitizePatch:
|
||||
"""Tests for the sanitize patching logic."""
|
||||
|
||||
def test_patch_preserves_layer_61(self) -> None:
|
||||
"""Patching sanitize should preserve layer 61 weights."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp, # pyright: ignore[reportPrivateUsage]
|
||||
_restore_deepseek_sanitize, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
deepseek_v3: Any = pytest.importorskip("mlx_lm.models.deepseek_v3") # pyright: ignore[reportAny]
|
||||
model_cls: Any = deepseek_v3.Model # pyright: ignore[reportAny]
|
||||
|
||||
# Get original sanitize behavior
|
||||
original_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
try:
|
||||
# Apply patch
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
# Note: we can't easily test the full sanitize without a real model
|
||||
# This test verifies the patch is applied
|
||||
assert model_cls.sanitize is not original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
# Verify restore worked
|
||||
assert model_cls.sanitize is original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
def test_restore_sanitize(self) -> None:
|
||||
"""Restoring sanitize should return to original behavior."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp, # pyright: ignore[reportPrivateUsage]
|
||||
_restore_deepseek_sanitize, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
deepseek_v3: Any = pytest.importorskip("mlx_lm.models.deepseek_v3") # pyright: ignore[reportAny]
|
||||
model_cls: Any = deepseek_v3.Model # pyright: ignore[reportAny]
|
||||
|
||||
original_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is not original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
def test_double_patch_is_safe(self) -> None:
|
||||
"""Calling patch twice should be safe (idempotent)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp, # pyright: ignore[reportPrivateUsage]
|
||||
_restore_deepseek_sanitize, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
deepseek_v3: Any = pytest.importorskip("mlx_lm.models.deepseek_v3") # pyright: ignore[reportAny]
|
||||
model_cls: Any = deepseek_v3.Model # pyright: ignore[reportAny]
|
||||
|
||||
original_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
try:
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
patched_sanitize: Any = model_cls.sanitize # pyright: ignore[reportAny]
|
||||
|
||||
# Patch again - should be no-op
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is patched_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
class TestModelIdDetection:
|
||||
"""Tests for DeepSeek V3 model ID detection."""
|
||||
|
||||
def test_detects_deepseek_v3(self) -> None:
|
||||
"""Should detect DeepSeek V3 model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
|
||||
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
|
||||
|
||||
def test_detects_deepseek_r1(self) -> None:
|
||||
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
|
||||
|
||||
def test_rejects_non_deepseek(self) -> None:
|
||||
"""Should reject non-DeepSeek model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
|
||||
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
|
||||
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
"""Detection should be case insensitive."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_might_be_deepseek_v3, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
|
||||
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
|
||||
|
||||
|
||||
class TestFlattenParams:
|
||||
"""Tests for parameter flattening utility."""
|
||||
|
||||
def test_flattens_nested_dict(self) -> None:
|
||||
"""Should flatten nested parameter dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_flatten_params, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
params = {
|
||||
"model": {
|
||||
"layers": {
|
||||
"0": {
|
||||
"weight": mx.zeros((10,)),
|
||||
}
|
||||
},
|
||||
"embed": mx.ones((5,)),
|
||||
}
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert "model.layers.0.weight" in flat
|
||||
assert "model.embed" in flat
|
||||
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
|
||||
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
|
||||
|
||||
def test_handles_flat_dict(self) -> None:
|
||||
"""Should handle already-flat dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_flatten_params, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
params = {
|
||||
"weight": mx.zeros((10,)),
|
||||
"bias": mx.ones((10,)),
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert flat == params
|
||||
@@ -1,260 +0,0 @@
|
||||
"""Unit tests for MTP speculative decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import (
|
||||
ModelWithHiddenStates,
|
||||
maybe_quantize_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
class MockModel(nn.Module):
|
||||
"""Mock model for testing speculative decoding."""
|
||||
|
||||
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Create simple model components
|
||||
self.model = MockInnerModel(hidden_size)
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list[object] | None = None,
|
||||
) -> mx.array:
|
||||
hidden = self.model(inputs, cache)
|
||||
return self.lm_head(hidden)
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Linear]:
|
||||
return self._layers
|
||||
|
||||
|
||||
class MockInnerModel(nn.Module):
|
||||
"""Mock inner model (like DeepseekV3Model)."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(100, hidden_size)
|
||||
self.norm = nn.RMSNorm(hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list[object] | None = None,
|
||||
) -> mx.array:
|
||||
# Simple embedding + norm
|
||||
embedded = self.embed_tokens(inputs)
|
||||
return self.norm(embedded)
|
||||
|
||||
|
||||
class TestModelWithHiddenStates:
|
||||
"""Tests for ModelWithHiddenStates wrapper."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self) -> MockModel:
|
||||
return MockModel(hidden_size=64, vocab_size=100)
|
||||
|
||||
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
|
||||
"""Standard forward should return logits."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits = wrapped.forward(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
|
||||
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
|
||||
"""Forward with hidden should return (logits, hidden)."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits, hidden = wrapped.forward_with_hidden(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
assert hidden.shape == (1, 3, mock_model.hidden_size)
|
||||
|
||||
def test_layers_property(self, mock_model: MockModel) -> None:
|
||||
"""Should expose layers property from base model."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
|
||||
assert wrapped.layers == mock_model.layers
|
||||
assert len(wrapped.layers) == 3
|
||||
|
||||
|
||||
class TestMaybeQuantizeKVCache:
|
||||
"""Tests for KV cache quantization."""
|
||||
|
||||
def test_no_quantization_when_bits_none(self) -> None:
|
||||
"""Should not quantize when kv_bits is None."""
|
||||
cache = [MockCache(offset=100)]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
cache,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=None,
|
||||
)
|
||||
|
||||
# Cache should be unchanged
|
||||
assert not hasattr(cache[0], "quantized")
|
||||
|
||||
def test_respects_quantized_kv_start(self) -> None:
|
||||
"""Should only quantize caches past the start threshold."""
|
||||
cache_below = MockCache(offset=30)
|
||||
cache_above = MockCache(offset=100)
|
||||
caches = [cache_below, cache_above]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
caches,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=4,
|
||||
)
|
||||
|
||||
# Only cache_above should be quantized
|
||||
assert not getattr(cache_below, "was_quantized", False)
|
||||
assert getattr(caches[1], "was_quantized", False)
|
||||
|
||||
|
||||
class MockCache:
|
||||
"""Mock KV cache for testing."""
|
||||
|
||||
def __init__(self, offset: int = 0) -> None:
|
||||
self.offset = offset
|
||||
self.was_quantized = False
|
||||
|
||||
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
|
||||
quantized = MockCache(self.offset)
|
||||
quantized.was_quantized = True
|
||||
return quantized
|
||||
|
||||
|
||||
class TestSpeculativeDecodingLogic:
|
||||
"""Tests for the core speculative decoding logic."""
|
||||
|
||||
def test_draft_acceptance_identical_tokens(self) -> None:
|
||||
"""When draft matches verification, both should be accepted."""
|
||||
# This tests the logic, not the full generator
|
||||
draft_token = 42
|
||||
verify_token = 42
|
||||
|
||||
accepted = draft_token == verify_token
|
||||
assert accepted
|
||||
|
||||
def test_draft_rejection_different_tokens(self) -> None:
|
||||
"""When draft differs from verification, draft should be rejected."""
|
||||
draft_token = 42
|
||||
verify_token = 99
|
||||
|
||||
accepted = int(draft_token) == int(verify_token)
|
||||
assert not accepted
|
||||
|
||||
|
||||
class TestMTPGenerationResponse:
|
||||
"""Tests for MTPGenerationResponse dataclass."""
|
||||
|
||||
def test_response_creation(self) -> None:
|
||||
"""Should create response with all fields."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="Hello",
|
||||
token=42,
|
||||
logprobs=mx.array([0.1, 0.2]),
|
||||
from_draft=True,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=5,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert response.text == "Hello"
|
||||
assert response.token == 42
|
||||
assert response.from_draft is True
|
||||
assert response.finish_reason is None
|
||||
|
||||
def test_response_with_finish_reason(self) -> None:
|
||||
"""Should handle finish_reason."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="",
|
||||
token=0,
|
||||
logprobs=mx.array([0.0]),
|
||||
from_draft=False,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=100,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason="length",
|
||||
)
|
||||
|
||||
assert response.finish_reason == "length"
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full MTP pipeline."""
|
||||
|
||||
def test_mtp_module_with_mock_model(self) -> None:
|
||||
"""Test MTP module can be created and run with mock components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Create mock config
|
||||
class MockConfig:
|
||||
hidden_size = 64
|
||||
intermediate_size = 128
|
||||
num_attention_heads = 2
|
||||
num_key_value_heads = 2
|
||||
rms_norm_eps = 1e-6
|
||||
q_lora_rank = None
|
||||
kv_lora_rank = 32
|
||||
qk_rope_head_dim = 8
|
||||
v_head_dim = 16
|
||||
qk_nope_head_dim = 16
|
||||
rope_theta = 10000.0
|
||||
rope_scaling = None
|
||||
attention_bias = False
|
||||
max_position_embeddings = 2048
|
||||
|
||||
config = MockConfig()
|
||||
embedding = nn.Embedding(100, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
hidden = mx.random.normal((1, 1, config.hidden_size))
|
||||
token = mx.array([[5]])
|
||||
|
||||
logits, new_hidden = mtp(hidden, token)
|
||||
|
||||
assert logits.shape == (1, 1, 100)
|
||||
assert new_hidden.shape == (1, 1, config.hidden_size)
|
||||
# Verify outputs are valid (not NaN)
|
||||
# mx.isnan forces GPU evaluation which may fail on CI runners
|
||||
# with incompatible Metal toolchains (e.g. missing steel_gemm kernels)
|
||||
try:
|
||||
assert not mx.any(mx.isnan(logits))
|
||||
assert not mx.any(mx.isnan(new_hidden))
|
||||
except RuntimeError as e:
|
||||
if "Unable to load kernel" in str(e):
|
||||
pytest.skip(f"Metal kernel not available on this device: {e}")
|
||||
raise
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -24,7 +23,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
MTP_ENABLED,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -66,144 +64,6 @@ from exo.worker.runner.bootstrap import logger
|
||||
Group = mx.distributed.Group
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
# MTP (Multi-Token Prediction) support for DeepSeek V3
|
||||
MTP_LAYER_INDEX = 61
|
||||
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
|
||||
|
||||
|
||||
def _is_deepseek_v3_model(model: nn.Module) -> bool:
|
||||
"""Check if the model is DeepSeek V3."""
|
||||
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def _might_be_deepseek_v3(model_id: str) -> bool:
|
||||
"""Check if model ID suggests this might be DeepSeek V3."""
|
||||
model_id_lower = model_id.lower()
|
||||
return "deepseek" in model_id_lower and (
|
||||
"v3" in model_id_lower or "r1" in model_id_lower
|
||||
)
|
||||
|
||||
|
||||
def _patch_deepseek_sanitize_for_mtp() -> None:
|
||||
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights."""
|
||||
global _original_deepseek_sanitize
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
|
||||
|
||||
if _original_deepseek_sanitize is not None:
|
||||
return
|
||||
|
||||
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
|
||||
|
||||
def sanitize_with_mtp(
|
||||
self: DeepSeekV3Model, weights: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if _original_deepseek_sanitize is None:
|
||||
raise RuntimeError(
|
||||
"_original_deepseek_sanitize is None - patch not applied correctly"
|
||||
)
|
||||
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
|
||||
mtp_weights = {
|
||||
k: v
|
||||
for k, v in weights.items() # pyright: ignore[reportAny]
|
||||
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
|
||||
}
|
||||
return {**original_result, **mtp_weights}
|
||||
|
||||
DeepSeekV3Model.sanitize = sanitize_with_mtp
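# Illustrative sketch, added for clarity (not part of the PR): the stock sanitize()
# strips weights it does not recognise, which includes the extra MTP layer; the
# patched version re-adds any key under model.layers.61 so the MTP head survives
# loading. Key names below are hypothetical:
#
#   >>> weights = {
#   ...     "model.layers.61.eh_proj.weight": mx.zeros((4, 8)),
#   ...     "model.layers.0.mlp.gate.weight": mx.zeros((4, 4)),
#   ... }
#   >>> sanitized = model.sanitize(weights)  # with the patch, both keys are kept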
|
||||
|
||||
|
||||
def _restore_deepseek_sanitize() -> None:
|
||||
"""Restore the original DeepSeek V3 sanitize method."""
|
||||
global _original_deepseek_sanitize
|
||||
if _original_deepseek_sanitize is None:
|
||||
return
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
|
||||
|
||||
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
|
||||
_original_deepseek_sanitize = None
|
||||
|
||||
|
||||
def _flatten_params(
|
||||
params: dict[str, Any],
|
||||
prefix: str = "",
|
||||
) -> dict[str, mx.array]:
|
||||
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
|
||||
result: dict[str, mx.array] = {}
|
||||
for key, value in params.items(): # pyright: ignore[reportAny]
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, mx.array):
|
||||
result[full_key] = value
|
||||
elif isinstance(value, dict):
|
||||
result.update(_flatten_params(cast(dict[str, Any], value), full_key))
|
||||
return result
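# Illustrative sketch, added for clarity (not part of the PR): _flatten_params turns
# the nested dict returned by model.parameters() into dot-separated keys, which is
# the shape extract_mtp_weights expects. The structure below is hypothetical:
#
#   >>> nested = {"model": {"layers": {"61": {"enorm": {"weight": mx.zeros((4,))}}}}}
#   >>> list(_flatten_params(nested))
#   ['model.layers.61.enorm.weight']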
|
||||
|
||||
|
||||
def _extract_mtp_module(model: nn.Module) -> Any | None:
|
||||
"""Extract MTP module from a loaded DeepSeek V3 model."""
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTPModule,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
try:
|
||||
inner_model: Any = getattr(model, "model", None)
|
||||
if inner_model is None or not hasattr(inner_model, "layers"): # pyright: ignore[reportAny]
|
||||
logger.debug("Model doesn't have expected structure for MTP extraction")
|
||||
return None
|
||||
|
||||
layers: list[nn.Module] = inner_model.layers # pyright: ignore[reportAny]
|
||||
if len(layers) <= MTP_LAYER_INDEX:
|
||||
logger.debug(
|
||||
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
|
||||
)
|
||||
return None
|
||||
|
||||
config: Any = getattr(model, "args", None)
|
||||
if config is None:
|
||||
logger.debug("Could not get model config for MTP module")
|
||||
return None
|
||||
|
||||
embed_tokens: Any = getattr(inner_model, "embed_tokens", None) # pyright: ignore[reportAny]
|
||||
lm_head: Any = getattr(model, "lm_head", None)
|
||||
norm: Any = getattr(inner_model, "norm", None) # pyright: ignore[reportAny]
|
||||
|
||||
if embed_tokens is None or lm_head is None or norm is None:
|
||||
logger.debug("Could not get required model components for MTP")
|
||||
return None
|
||||
|
||||
mtp_module = MTPModule(
|
||||
config=config, # pyright: ignore[reportAny]
|
||||
shared_embedding=embed_tokens, # pyright: ignore[reportAny]
|
||||
shared_lm_head=lm_head, # pyright: ignore[reportAny]
|
||||
output_norm=norm, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
|
||||
model_weights = _flatten_params(raw_params)
|
||||
mtp_weights = extract_mtp_weights(model_weights)
|
||||
|
||||
if not mtp_weights:
|
||||
logger.debug("No MTP weights found in model parameters")
|
||||
return None
|
||||
|
||||
load_mtp_weights_into_module(mtp_module, mtp_weights)
|
||||
|
||||
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
|
||||
inner_model.layers = new_layers # noqa: B010
|
||||
|
||||
logger.info(
|
||||
f"Extracted MTP module, main model now has {len(new_layers)} layers"
|
||||
)
|
||||
return mtp_module
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract MTP module: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -221,30 +81,6 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -336,52 +172,28 @@ def load_mlx_items(
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
model_id = bound_instance.bound_shard.model_card.model_id
|
||||
mtp_module = None
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
# Patch sanitize for MTP if this might be DeepSeek V3
|
||||
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
|
||||
if should_try_mtp:
|
||||
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
try:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=not should_try_mtp)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
# Extract MTP module if available
|
||||
if should_try_mtp and _is_deepseek_v3_model(model):
|
||||
mtp_module = _extract_mtp_module(model)
|
||||
if mtp_module is not None:
|
||||
logger.info("Successfully extracted MTP module from DeepSeek V3")
|
||||
|
||||
finally:
|
||||
if should_try_mtp:
|
||||
_restore_deepseek_sanitize()
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
# Store MTP module on the model for later access
|
||||
if mtp_module is not None:
|
||||
model.mtp_module = mtp_module # noqa: B010
|
||||
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
@@ -541,7 +353,13 @@ def load_tokenizer_for_model_id(
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
|
||||
return TokenizerWrapper(
|
||||
hf_tokenizer,
|
||||
eos_token_ids=eos_token_ids,
|
||||
tool_call_start="<|tool_calls_section_begin|>",
|
||||
tool_call_end="<|tool_calls_section_end|>",
|
||||
tool_parser=_parse_kimi_tool_calls,
|
||||
)
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
@@ -753,3 +571,61 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
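# Illustrative note, added for clarity (not part of the PR): both helpers lean on the
# same trick - an all_sum on the CPU stream doubles as a synchronisation point.
# For mx_any every rank contributes 0 or 1, so the sum is > 0 iff any rank passed
# True; mx_barrier runs the same collective purely for its blocking side effect.
#
#   >>> mx_any(False, group=None)   # single-device fast path, no collective issued
#   False
#   >>> # with a 2-rank group where rank 0 passes True, all_sum yields 1 on both
#   >>> # ranks, so both return True (hypothetical distributed run)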
|
||||
|
||||
|
||||
def _parse_kimi_tool_calls(text: str):
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
_tool_call_split_regex = re.compile(
|
||||
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
|
||||
)
|
||||
|
||||
def _parse_single_tool(text: str) -> dict[str, Any]:
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError("No tool call found.")
|
||||
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
|
||||
func_name = func_name_match.group(2) # e.g. "get_weather"
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError("No tool call arguments found.")
|
||||
func_args = func_args_match.group(1)
|
||||
try:
|
||||
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
arg_dct = None
|
||||
|
||||
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
|
||||
|
||||
tool_matches = _tool_call_split_regex.findall(text)
|
||||
if tool_matches:
|
||||
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
|
||||
else:
|
||||
return [_parse_single_tool(text)]
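# Illustrative sketch, added for clarity (not part of the PR): given Kimi-K2 output in
# the format shown above, one dict is returned per tool call section. Hypothetical
# input:
#
#   >>> text = ("<|tool_call_begin|>functions.multiply:0"
#   ...         '<|tool_call_argument_begin|>{"a": 2, "b": 3}<|tool_call_end|>')
#   >>> _parse_kimi_tool_calls(text)
#   [{'id': 'functions.multiply:0', 'name': 'multiply', 'arguments': {'a': 2, 'b': 3}}]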
|
||||
|
||||
@@ -33,6 +33,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -224,15 +225,22 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
await runner.start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -270,18 +278,18 @@ class Worker:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
await self._start_runner_task(modified_task)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
await self._start_runner_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -320,8 +328,6 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -53,13 +54,14 @@ def plan(
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
)
|
||||
|
||||
|
||||
@@ -270,7 +272,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -284,7 +286,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -292,16 +294,33 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner_id, runner in runners.items():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id,
|
||||
cancelled_task_id=task.task_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,8 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxDevice, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
@@ -15,6 +15,7 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
@@ -34,11 +35,20 @@ def entrypoint(
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Set MLX compute device before importing runner (which imports mlx.core at module scope)
|
||||
mlx_device = bound_instance.instance.mlx_device
|
||||
if mlx_device != MlxDevice.Auto:
|
||||
import mlx.core as mx
|
||||
|
||||
device = mx.cpu if mlx_device == MlxDevice.Cpu else mx.gpu
|
||||
mx.set_default_device(device)
|
||||
logger.info(f"MLX device set to: {mlx_device}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import resource
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Any, Callable, Literal
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -15,7 +15,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
@@ -88,9 +87,12 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
from .tool_parsers import ToolParser, make_mlx_parser
|
||||
|
||||
|
||||
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
"""Check if this node is the primary output node for image generation.
|
||||
@@ -112,6 +114,7 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||
@@ -129,11 +132,16 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
model: Model | DistributedImageModel | None = None
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
tool_parser: ToolParser | None = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -146,6 +154,7 @@ def main(
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -191,19 +200,28 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
|
||||
)
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
tool_parser = make_mlx_parser(
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
@@ -211,8 +229,6 @@ def main(
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -224,16 +240,31 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
t = time.monotonic()
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / max(time.monotonic() - t, 0.001)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
mx.max(
|
||||
mx.distributed.all_gather(
|
||||
mx.array([check_for_cancel_every]), group=group
|
||||
)
|
||||
).item()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||
)
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
@@ -241,8 +272,8 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
@@ -262,9 +293,9 @@ def main(
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
assert check_for_cancel_every
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params)
|
||||
@@ -274,7 +305,7 @@ def main(
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
@@ -289,34 +320,25 @@ def main(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
if isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
elif tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
@@ -364,6 +386,7 @@ def main(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -388,7 +411,7 @@ def main(
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
assert image_model
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -401,7 +424,9 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
||||
|
||||
if is_primary_output:
|
||||
@@ -451,7 +476,7 @@ def main(
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
assert image_model
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -464,7 +489,9 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
@@ -523,14 +550,20 @@ def main(
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
was_cancelled = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if not was_cancelled:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
@@ -544,21 +577,8 @@ def get_gpt_oss_encoding():
|
||||
return encoding
|
||||
|
||||
|
||||
def filter_kimi_tokens(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
for resp in responses:
|
||||
assert isinstance(resp, GenerationResponse)
|
||||
if (
|
||||
resp.text == "<|tool_calls_section_begin|>"
|
||||
or resp.text == "<|tool_calls_section_end|>"
|
||||
):
|
||||
continue
|
||||
yield resp
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
@@ -615,9 +635,9 @@ def parse_gpt_oss(
|
||||
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
responses: Generator[GenerationResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
@@ -738,218 +758,55 @@ def _process_image_response(
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
responses: Generator[GenerationResponse], tool_parser: ToolParser
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
in_tool_call = False
|
||||
tool_call_text_parts: list[str] = []
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
# assumption: the tool call start is one token
|
||||
if response.text == tool_call_start:
|
||||
if response.text.startswith(tool_parser.start_parsing):
|
||||
in_tool_call = True
|
||||
continue
|
||||
# assumption: the tool call end is one token
|
||||
if in_tool_call and response.text == tool_call_end:
|
||||
try:
|
||||
# tool_parser returns an arbitrarily nested python dictionary
|
||||
# we actually don't want the python dictionary, we just want to
|
||||
# parse the top level { function: ..., arguments: ... } structure
|
||||
# as we're just gonna hand it back to the api anyway
|
||||
parsed = tool_parser("".join(tool_call_text_parts).strip())
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if isinstance(parsed, list):
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
|
||||
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
|
||||
)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if in_tool_call:
|
||||
tool_call_text_parts.append(response.text)
|
||||
if response.text.endswith(tool_parser.end_parsing):
|
||||
# parse the actual tool calls from the tool call text
|
||||
parsed = tool_parser.parse_tool_calls(
|
||||
"".join(tool_call_text_parts).strip()
|
||||
)
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if parsed is not None:
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed, usage=response.usage, stats=response.stats
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
|
||||
)
|
||||
response.text = "".join(tool_call_text_parts)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if response.finish_reason is not None:
|
||||
logger.info(
|
||||
"toll call parsing interrupted, yield partial tool call as text"
|
||||
"tool call parsing interrupted, yield partial tool call as text"
|
||||
)
|
||||
yield GenerationResponse(
|
||||
text=tool_call_start + "".join(tool_call_text_parts),
|
||||
token=0,
|
||||
finish_reason=response.finish_reason,
|
||||
usage=None,
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"text": "".join(tool_call_text_parts),
|
||||
"token": 0,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
continue
|
||||
|
||||
# fallthrough
|
||||
yield response
|
||||
|
||||
|
||||
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Version of to-be-upstreamed kimi-k2 tool parser
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
|
||||
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
|
||||
tool_call_start = "<|tool_call_begin|>"
|
||||
tool_call_end = "<|tool_call_end|>"
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
original_func_name = func_name_match.group(1)
|
||||
tool_id = func_name_match.group(2)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = original_func_name[original_func_name.find(".") + 1 :]
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
return dict(
|
||||
id=f"{original_func_name}:{tool_id}",
|
||||
name=func_name,
|
||||
arguments=arg_dct, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
|
||||
_func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
tool_call_start = "<tool_call>"
|
||||
tool_call_end = "</tool_call>"
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Any] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools: # pyright: ignore[reportAny]
|
||||
func = tool["function"] # pyright: ignore[reportAny]
|
||||
if func["name"] == tool_name:
|
||||
params = func["parameters"] # pyright: ignore[reportAny]
|
||||
if params is None:
|
||||
return False
|
||||
props = params.get("properties", {}) # pyright: ignore[reportAny]
|
||||
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
|
||||
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
|
||||
return arg_type == "string" # pyright: ignore[reportAny]
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: list[Any] | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
pairs = _func_arg_regex.findall(text)
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs: # pyright: ignore[reportAny]
|
||||
arg_key = key.strip() # pyright: ignore[reportAny]
|
||||
arg_val = value.strip() # pyright: ignore[reportAny]
|
||||
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
|
||||
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
|
||||
arg_dct[arg_key] = arg_val
|
||||
return dict(name=func_name, arguments=arg_dct)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
and ((args := obj.get("arguments")) is not None)
|
||||
and isinstance(name, str)
|
||||
):
|
||||
raw_id: object = obj.get("id")
|
||||
extra = {"id": str(raw_id)} if raw_id is not None else {}
|
||||
return ToolCallItem(
|
||||
**extra,
|
||||
name=name,
|
||||
arguments=json.dumps(args),
|
||||
)
|
||||
else:
|
||||
raise ValidationError
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -47,9 +47,11 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -60,8 +62,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -69,6 +71,7 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -83,6 +86,7 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -97,6 +101,8 @@ class RunnerSupervisor:
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
@@ -112,14 +118,6 @@ class RunnerSupervisor:
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
@@ -141,6 +139,17 @@ class RunnerSupervisor:
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
with anyio.move_on_after(0.5) as scope:
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
if scope.cancel_called:
|
||||
logger.error("RunnerSupervisor cancel pipe blocked")
|
||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
try:
|
||||
|
||||
src/exo/worker/runner/tool_parsers.py (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolParser:
|
||||
start_parsing: str
|
||||
end_parsing: str
|
||||
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
|
||||
|
||||
|
||||
def make_mlx_parser(
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> ToolParser:
|
||||
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix(tool_call_start)
|
||||
text = text.removesuffix(tool_call_end)
|
||||
parsed = tool_parser(text)
|
||||
if isinstance(parsed, list):
|
||||
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
|
||||
else:
|
||||
return [ToolCallItem.model_validate(_flatten(parsed))]
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return ToolParser(
|
||||
start_parsing=tool_call_start,
|
||||
end_parsing=tool_call_end,
|
||||
parse_tool_calls=parse_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
# TODO / example code:
|
||||
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix("<tool_call>")
|
||||
text = text.removesuffix("</tool_call>")
|
||||
top_level = {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else v
|
||||
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
|
||||
}
|
||||
return [ToolCallItem.model_validate(top_level)]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _flatten(p: dict[str, Any]) -> dict[str, str]:
|
||||
return {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
|
||||
for k, v in p.items() # pyright: ignore[reportAny]
|
||||
}
|
||||
|
||||
|
||||
json_tool_parser = ToolParser(
|
||||
start_parsing="<tool_call>",
|
||||
end_parsing="</tool_call>",
|
||||
parse_tool_calls=_parse_json_calls,
|
||||
)
|
||||
|
||||
|
||||
def infer_tool_parser(chat_template: str) -> ToolParser | None:
|
||||
"""Attempt to auto-infer a tool parser from the chat template."""
|
||||
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
|
||||
return json_tool_parser
|
||||
return None
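# Illustrative sketch, added for clarity (not part of the PR): the JSON parser above
# handles chat templates that emit <tool_call>{...}</tool_call>. Tool name and
# arguments below are hypothetical:
#
#   >>> text = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
#   >>> [t.name for t in json_tool_parser.parse_tool_calls(text) or []]
#   ['get_weather']
#
# make_mlx_parser wraps a tokenizer-supplied parser the same way; the _dummy_parser
# fixture in the tests below shows a concrete call.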
|
||||
@@ -1,7 +1,9 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
import unittest.mock
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
@@ -19,6 +21,7 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -113,6 +116,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
|
||||
event_sender = EventCollector()
|
||||
|
||||
with task_sender:
|
||||
@@ -173,8 +178,16 @@ def _run(tasks: Iterable[Task]):
|
||||
# this is some c++ nonsense
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
|
||||
with unittest.mock.patch(
|
||||
"exo.worker.runner.runner.mx.distributed.all_gather",
|
||||
make_nothin(mx.array([1])),
|
||||
):
|
||||
mlx_runner.main(
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
cancel_receiver,
|
||||
)
|
||||
|
||||
return event_sender.events
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@ from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
from exo.worker.runner.runner import parse_tool_calls
|
||||
from exo.worker.runner.tool_parsers import make_mlx_parser
|
||||
|
||||
|
||||
def _make_responses(
|
||||
texts: list[str],
|
||||
finish_on_last: bool = True,
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Create a sequence of GenerationResponses from text strings."""
|
||||
for i, text in enumerate(texts):
|
||||
is_last = i == len(texts) - 1
|
||||
@@ -22,10 +23,13 @@ def _make_responses(
|
||||
)
|
||||
|
||||
|
||||
def _dummy_parser(text: str) -> dict[str, Any]:
|
||||
def _dummier_parser(text: str) -> dict[str, Any]:
|
||||
return {"name": "test_fn", "arguments": {"arg": text}}
|
||||
|
||||
|
||||
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
|
||||
|
||||
|
||||
class TestParseToolCalls:
|
||||
"""Tests for parse_tool_calls generator."""
|
||||
|
||||
@@ -35,8 +39,6 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts, finish_on_last=False),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_dummy_parser,
|
||||
)
|
||||
)
|
||||
@@ -50,8 +52,6 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_dummy_parser,
|
||||
)
|
||||
)
|
||||
@@ -76,9 +76,7 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts, finish_on_last=False),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_failing_parser,
|
||||
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||