Compare commits

...

3 Commits

Author SHA1 Message Date
Alex Cheema
72fca71522 use model.make_cache in make_kv_cache 2025-12-30 17:46:13 +00:00
Alex Cheema
16e2bfd3b3 log EXO_LIBP2P_NAMESPACE on start 2025-12-30 04:08:47 +00:00
Alex Cheema
ade3ee7ec5 fix warmup order. should be rank!=0 then rank=0 2025-12-30 03:29:34 +00:00
3 changed files with 10 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
import argparse
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -194,6 +195,7 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
node = anyio.run(Node.create, args)
anyio.run(node.run)

View File

@@ -343,6 +343,10 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
if hasattr(model, "make_cache"):
logger.info(f"Using make_cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -235,9 +235,8 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# TODO: Ensure these align with MLX distributeds expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -245,8 +244,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id