mirror of https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00

Compare commits
2 Commits
remove-pyt
...

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | ccf4d91d55 | |
| alexcheema | 811a4d80bd | |
nix/mlx.nix
@@ -41,7 +41,7 @@ let
   mlx = stdenv.mkDerivation rec {
     pname = "mlx";
-    version = let v = "0.30.7.dev20260220+13998a05"; in
+    version = let v = "0.30.7.dev20260224+5289547a"; in
       assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
       v;
     pyproject = true;
@@ -49,8 +49,8 @@ let
     src = fetchFromGitHub {
       owner = "rltakashige";
       repo = "mlx-jaccl-fix-small-recv";
-      rev = "13998a054715edcdc93618fb1496c79c7c25ff7c";
-      hash = "sha256-fAqA3hFwNBx7FcoGnhQsIFpAIRbC2EerACm4Fvne0Cc=";
+      rev = "5289547ada1cddda2b9716baf6a077a906d02189";
+      hash = "sha256-Zp9Jln7+Fpn79OfnIdiIVYzQDpih9lHrKtKJadh+c0I=";
     };

     patches = [
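
The assert in the hunk above keeps nix/mlx.nix and uv.lock pointing at the same MLX revision, which is why this bump touches both files. A minimal sketch of how the uv.lock side can be inspected, assuming only uv.lock's standard TOML layout (`[[package]]` entries with a `source.git` field); this helper is not part of the change:

```python
# Print the git-pinned mlx version recorded in uv.lock, for comparison
# against the version/rev pinned in nix/mlx.nix.
import tomllib

with open("uv.lock", "rb") as f:
    lock = tomllib.load(f)

for pkg in lock["package"]:
    if pkg["name"] == "mlx" and "git" in pkg.get("source", {}):
        print(pkg["version"], "->", pkg["source"]["git"])
```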
@@ -128,12 +128,25 @@ class PipelineFirstLayer(CustomMlxLayer):

     def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
         if self.r != 0:
+            import time as _time
+            _t0 = _time.perf_counter()
             x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
             if self.is_prefill:
                 # We want to avoid GPU timeout errors by evalling the distributed operation
                 # so that it stays on CPU, which does not have a timeout.
                 mx.eval(x)
-        return self.original_layer(x, *args, **kwargs)
+            _elapsed = _time.perf_counter() - _t0
+            if _elapsed > 1.0:
+                import logging
+                logging.getLogger(__name__).warning(f"[PIPELINE] PipelineFirstLayer recv_like+eval took {_elapsed:.4f}s (SLOW)")
+        _t0_layer = _time.perf_counter() if self.r != 0 else None
+        result = self.original_layer(x, *args, **kwargs)
+        if _t0_layer is not None:
+            _elapsed_layer = _time.perf_counter() - _t0_layer
+            if _elapsed_layer > 1.0:
+                import logging
+                logging.getLogger(__name__).warning(f"[PIPELINE] PipelineFirstLayer original_layer took {_elapsed_layer:.4f}s (SLOW)")
+        return result


 class PipelineLastLayer(CustomMlxLayer):
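
The same measure-and-warn pattern is inlined at every instrumented call site in this diff. A minimal sketch of how it could be factored into a context manager (the `slow_op` helper is hypothetical, not part of the change):

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def slow_op(label: str, threshold_s: float = 1.0):
    """Log a warning if the wrapped block takes longer than threshold_s."""
    t0 = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - t0
        if elapsed > threshold_s:
            logger.warning(f"[PIPELINE] {label} took {elapsed:.4f}s (SLOW)")

# Usage, mirroring PipelineFirstLayer.__call__:
#     with slow_op("PipelineFirstLayer recv_like+eval"):
#         x = mx.distributed.recv_like(x, self.r - 1, group=self.group)
#         mx.eval(x)
```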
@@ -152,13 +165,20 @@ class PipelineLastLayer(CustomMlxLayer):
         self.is_prefill: bool = False

     def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+        import time as _time
         cache = self.original_layer_signature.bind_partial(
             x, *args, **kwargs
         ).arguments.get("cache", None)

+        _t0 = _time.perf_counter()
         output: mx.array = self.original_layer(x, *args, **kwargs)
+        _elapsed = _time.perf_counter() - _t0
+        if _elapsed > 1.0:
+            import logging
+            logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer original_layer took {_elapsed:.4f}s (SLOW)")

         if self.r != self.s - 1:
+            _t0 = _time.perf_counter()
             output = mx.distributed.send(
                 output, (self.r + 1) % self.s, group=self.group
             )
@@ -171,11 +191,20 @@ class PipelineLastLayer(CustomMlxLayer):
             mx.eval(output)
             if cache is not None:
                 mx.eval(_cache.keys)  # type: ignore
+            _elapsed = _time.perf_counter() - _t0
+            if _elapsed > 1.0:
+                import logging
+                logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer send+eval took {_elapsed:.4f}s (SLOW)")

         if not self.is_prefill:
+            _t0 = _time.perf_counter()
             output = mx.distributed.all_gather(output, group=self.group)[
                 -output.shape[0] :
             ]
+            _elapsed = _time.perf_counter() - _t0
+            if _elapsed > 1.0:
+                import logging
+                logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer all_gather took {_elapsed:.4f}s (SLOW)")

         return output
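
Why the `[-output.shape[0]:]` slice after all_gather: MLX's all_gather concatenates each rank's array along the first axis, so the last `output.shape[0]` rows are always the contribution of the highest rank, the one holding the final pipeline stage. A plain-numpy illustration of that indexing (not part of the change):

```python
import numpy as np

# Stand-ins for each rank's `output` in a 4-way pipeline group.
world = [np.full((2, 3), rank) for rank in range(4)]
gathered = np.concatenate(world, axis=0)    # shape (8, 3), like all_gather's result
last_stage = gathered[-world[0].shape[0]:]  # the last rank's rows, on every rank
assert (last_stage == 3).all()
```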
@@ -94,14 +94,20 @@ def prefill(
         if on_prefill_progress is not None:
             on_prefill_progress(processed, total)

+    t0 = time.perf_counter()
     set_pipeline_prefill(model, is_prefill=True)
+    logger.warning(f"[PREFILL] set_pipeline_prefill(True) took {time.perf_counter() - t0:.4f}s")

+    t0 = time.perf_counter()
     mx_barrier(group)
-    logger.info("Starting prefill")
+    logger.warning(f"[PREFILL] mx_barrier (pre-prefill) took {time.perf_counter() - t0:.4f}s")

+    logger.warning("[PREFILL] Starting prefill via stream_generate")

     # Use max_tokens=1 because max_tokens=0 does not work.
     # We just throw away the generated token - we only care about filling the cache
     try:
+        t0 = time.perf_counter()
         for _ in stream_generate(
             model=model,
             tokenizer=tokenizer,
@@ -114,15 +120,19 @@ def prefill(
             kv_bits=KV_BITS,
             prompt_progress_callback=progress_callback,
         ):
+            logger.warning(f"[PREFILL] stream_generate first yield took {time.perf_counter() - t0:.4f}s")
             break  # Stop after first iteration - cache is now filled
     except PrefillCancelled:
         set_pipeline_prefill(model, is_prefill=False)
         raise

+    t0 = time.perf_counter()
     set_pipeline_prefill(model, is_prefill=False)
+    logger.warning(f"[PREFILL] set_pipeline_prefill(False) took {time.perf_counter() - t0:.4f}s")

     # stream_generate added 1 extra generated token to the cache, so we should trim it.
     # Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.
+    t0 = time.perf_counter()
     pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
     for i, c in enumerate(cache):
         if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
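
The comments in this hunk describe the prefill trick: run the generator for a single token purely to populate the KV cache, discard the token, then trim the extra cache entries it appended. A hedged standalone sketch of the same idea against mlx_lm's public API (assuming `stream_generate` accepts a `prompt_cache` kwarg and cache entries expose `trim`, as in current mlx_lm; treat the exact names as assumptions):

```python
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
cache = make_prompt_cache(model)

# max_tokens=1 because max_tokens=0 does not work; the generated token is
# discarded -- the point is the side effect of filling `cache`.
for _ in stream_generate(model, tokenizer, prompt="Hello", max_tokens=1,
                         prompt_cache=cache):
    break  # first yield means the prompt has been processed into the cache

for c in cache:
    c.trim(1)  # drop the throwaway generated token from the cache
```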
@@ -132,11 +142,12 @@ def prefill(
         else:
             assert not isinstance(c, (ArraysCache, RotatingKVCache))
             c.trim(2)  # pyright: ignore[reportUnknownMemberType]
+    logger.warning(f"[PREFILL] cache trim took {time.perf_counter() - t0:.4f}s")

     elapsed = time.perf_counter() - start_time
     tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
-    logger.debug(
-        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
+    logger.warning(
+        f"[PREFILL] complete: {num_tokens} tokens in {elapsed:.2f}s "
         f"({tokens_per_sec:.1f} tok/s)"
     )
     # Exclude the last snapshot
@@ -324,6 +335,8 @@ def mlx_generate(
     max_stop_len = max((len(s) for s in stop_sequences), default=0)

     # Prefill cache with all tokens except the last one
+    logger.warning(f"[GENERATE] calling prefill with {len(prompt_tokens) - 1} tokens")
+    t_prefill_start = time.perf_counter()
     prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
         model,
         tokenizer,
@@ -333,6 +346,7 @@ def mlx_generate(
         group,
         on_prefill_progress,
     )
+    logger.warning(f"[GENERATE] prefill() returned in {time.perf_counter() - t_prefill_start:.4f}s")
     cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

     # stream_generate starts from the last token
@@ -348,9 +362,12 @@ def mlx_generate(
     think_start = tokenizer.think_start
     think_end = tokenizer.think_end

-    logger.info("Starting decode")
+    logger.warning("[GENERATE] Starting decode")
+    t0 = time.perf_counter()
     mx_barrier(group)
+    logger.warning(f"[GENERATE] mx_barrier (pre-decode) took {time.perf_counter() - t0:.4f}s")

+    _decode_token_start = time.perf_counter()
     for completion_tokens, out in enumerate(
         stream_generate(
             model=model,
@@ -366,6 +383,9 @@ def mlx_generate(
         ),
         start=1,
     ):
+        _decode_token_elapsed = time.perf_counter() - _decode_token_start
+        if _decode_token_elapsed > 1.0:
+            logger.warning(f"[DECODE] token {completion_tokens} took {_decode_token_elapsed:.4f}s (SLOW)")
         generated_text_parts.append(out.text)
         accumulated_text += out.text

@@ -488,9 +508,12 @@ def mlx_generate(
             )

         if is_done:
+            t0 = time.perf_counter()
             mx_barrier(group)
+            logger.warning(f"[GENERATE] mx_barrier (post-decode) took {time.perf_counter() - t0:.4f}s")
             break

         # Limit accumulated_text to what's needed for stop sequence detection
         if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
             accumulated_text = accumulated_text[-max_stop_len:]
+        _decode_token_start = time.perf_counter()
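
The decode-loop instrumentation above resets `_decode_token_start` at the bottom of the loop body and checks it on the next yield, so each measurement spans exactly one blocking step of the generator. Distilled into a standalone wrapper (a hypothetical helper, not part of the change):

```python
import time
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def warn_slow_items(gen: Iterable[T], threshold_s: float = 1.0) -> Iterator[T]:
    """Yield items from gen, warning when producing one takes too long."""
    t0 = time.perf_counter()
    for i, item in enumerate(gen, start=1):
        elapsed = time.perf_counter() - t0
        if elapsed > threshold_s:
            print(f"[DECODE] token {i} took {elapsed:.4f}s (SLOW)")
        yield item
        t0 = time.perf_counter()  # restart the clock for the next item
```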
@@ -276,31 +276,38 @@ def main(
             _task_id: TaskId = task.task_id,
             _group: mx.distributed.Group | None = group,
         ) -> None:
-            if device_rank == 0:
-                event_sender.send(
-                    ChunkGenerated(
-                        command_id=command_id,
-                        chunk=PrefillProgressChunk(
-                            model=shard_metadata.model_card.model_id,
-                            processed_tokens=processed,
-                            total_tokens=total,
-                        ),
-                    )
-                )
-            cancelled_tasks.update(cancel_receiver.collect())
-            want_to_cancel = (_task_id in cancelled_tasks) or (
-                TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
-            )
-            if mx_any(want_to_cancel, _group):
-                raise PrefillCancelled()
+            time.sleep(0.2)
+            return None
+            # if device_rank == 0:
+            #     event_sender.send(
+            #         ChunkGenerated(
+            #             command_id=command_id,
+            #             chunk=PrefillProgressChunk(
+            #                 model=shard_metadata.model_card.model_id,
+            #                 processed_tokens=processed,
+            #                 total_tokens=total,
+            #             ),
+            #         )
+            #     )
+            # cancelled_tasks.update(cancel_receiver.collect())
+            # want_to_cancel = (_task_id in cancelled_tasks) or (
+            #     TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
+            # )
+            # if mx_any(want_to_cancel, _group):
+            #     raise PrefillCancelled()

         try:
+            import time as _time
+            _runner_req_start = _time.perf_counter()
             _check_for_debug_prompts(task_params)

             # Build prompt once - used for both generation and thinking detection
+            _t0 = _time.perf_counter()
             prompt = apply_chat_template(tokenizer, task_params)
+            logger.warning(f"[RUNNER] apply_chat_template took {_time.perf_counter() - _t0:.4f}s")

             # Generate responses using the actual MLX generation
+            logger.warning("[RUNNER] calling mlx_generate")
             mlx_generator = mlx_generate(
                 model=cast(Model, inference_model),
                 tokenizer=tokenizer,
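
Note what this hunk does: the prefill progress/cancellation callback body is stubbed out with `time.sleep(0.2); return None`, and the event-sending plus `mx_any` cancellation vote survive only as comments, presumably to rule the callback out while chasing the GPU lock. The `mx_any` vote matters because every rank must take the same branch: if one rank raises PrefillCancelled while the others proceed into the next collective, the group deadlocks. A sketch of such an all-reduce "any" vote, assuming `mx_any` is built on `mx.distributed.all_sum` (its real implementation is not shown in this diff):

```python
import mlx.core as mx

def mx_any(flag: bool, group: mx.distributed.Group | None) -> bool:
    """True on every rank iff at least one rank passed flag=True."""
    votes = mx.distributed.all_sum(mx.array([1 if flag else 0]), group=group)
    return bool(votes.item() > 0)
```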
@@ -332,6 +339,8 @@ def main(

             completion_tokens = 0
             tokens_since_last_cancel_check = check_for_cancel_every
+            logger.warning("[RUNNER] starting token iteration loop")
+            _runner_token_start = _time.perf_counter()
             for response in mlx_generator:
                 tokens_since_last_cancel_check += 1
                 if tokens_since_last_cancel_check >= check_for_cancel_every:
@@ -413,6 +422,7 @@ def main(
                 )
                 raise

+            logger.warning(f"[RUNNER] request complete in {_time.perf_counter() - _runner_req_start:.4f}s total")
             current_status = RunnerReady()
             logger.info("runner ready")
tmp/reproduce_gpu_lock.py (new file, 279 lines)
@@ -0,0 +1,279 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# ///
"""Reproduce GPU lock issue with mlx-community/Llama-3.2-1B-Instruct-4bit.

Starts exo or mlx_lm.server, then sends repeated chat completions
until a request stalls for >5 seconds (indicating a GPU lock).

Usage:
    uv run tmp/reproduce_gpu_lock.py           # use exo (default)
    uv run tmp/reproduce_gpu_lock.py --mlx-lm  # use mlx_lm.server
"""

import argparse
import hashlib
import json
import os
import platform
import random
import signal
import subprocess
import sys
import threading
import time
import urllib.error
import urllib.request
import uuid

MODEL_ID = "mlx-community/Llama-3.2-1B-Instruct-4bit"
MODEL_PATH = os.path.expanduser("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit")
STALL_THRESHOLD_S = 5.0

server_proc = None
base_url = ""


def cleanup(*_):
    if server_proc and server_proc.poll() is None:
        print("\nStopping server...")
        server_proc.terminate()
        try:
            server_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            server_proc.kill()
    sys.exit(0)


signal.signal(signal.SIGINT, cleanup)
signal.signal(signal.SIGTERM, cleanup)

def api_get(path, timeout=30):
    req = urllib.request.Request(f"{base_url}{path}")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())


def api_post(path, body, timeout=300):
    data = json.dumps(body).encode()
    req = urllib.request.Request(f"{base_url}{path}", data=data, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())


def wait_for_api(max_wait=120):
    print("Waiting for API to be ready...", flush=True)
    start = time.time()
    while time.time() - start < max_wait:
        try:
            api_get("/v1/models", timeout=5)
            print("API is ready.", flush=True)
            return
        except Exception:
            time.sleep(2)
    print("ERROR: API did not become ready in time.", flush=True)
    cleanup()


def create_instance(max_wait=120):
    print(f"Waiting for valid placements for {MODEL_ID}...", flush=True)
    start = time.time()
    valid = []
    while time.time() - start < max_wait:
        try:
            previews = api_get(f"/instance/previews?model_id={MODEL_ID}")
            valid = [p for p in previews.get("previews", []) if p.get("error") is None and p.get("instance") is not None]
            if valid:
                break
        except Exception:
            pass
        time.sleep(3)
    if not valid:
        print("ERROR: No valid placements found after waiting.", flush=True)
        cleanup()

    instance = valid[0]["instance"]
    print(f"Creating instance (sharding={valid[0].get('sharding')}, meta={valid[0].get('instance_meta')})...", flush=True)
    resp = api_post("/instance", {"instance": instance})
    print(f"Instance creation requested: {resp.get('message')} (command_id={resp.get('command_id')})", flush=True)
    return instance.get("id") or instance.get("instance_id")


def wait_for_instance(max_wait=120):
    print("Waiting for instance to be ready...", flush=True)
    start = time.time()
    while time.time() - start < max_wait:
        try:
            state = api_get("/state", timeout=10)
            instances = state.get("instances") or state.get("model_instances") or {}
            if instances:
                print(f"Instance ready. ({len(instances)} instance(s) in state)", flush=True)
                return
        except Exception:
            pass
        time.sleep(3)
    print("WARNING: Timed out waiting for instance in state. Proceeding anyway...", flush=True)

TOPICS = [
    "the weather", "cats", "space", "pizza", "music", "ocean", "mountains",
    "robots", "books", "coffee", "trains", "clouds", "birds", "fire",
    "ice cream", "trees", "rivers", "stars", "thunder", "gardens",
]


def send_chat(request_num):
    topic = random.choice(TOPICS)
    nonce = random.randint(1000, 9999)
    body = {
        "model": MODEL_ID,
        "messages": [{"role": "user", "content": f"Say something about {topic} in one sentence. ({nonce})"}],
        "stream": False,
        "max_tokens": 64,
    }
    start = time.time()
    resp = api_post("/v1/chat/completions", body, timeout=600)
    elapsed = time.time() - start
    return elapsed, resp

def start_exo():
    global server_proc
    machine_id = hashlib.sha256(f"{platform.node()}-{uuid.getnode()}".encode()).hexdigest()[:12]
    namespace = f"gpu-lock-repro-{machine_id}"

    log_file = open("/tmp/exo_gpu_lock_repro.log", "w", buffering=1)
    print(f"\nStarting exo (namespace={namespace})...", flush=True)
    print("Log: /tmp/exo_gpu_lock_repro.log", flush=True)
    print("  tail -f /tmp/exo_gpu_lock_repro.log  (in another terminal to watch)", flush=True)
    env = {**os.environ, "EXO_LIBP2P_NAMESPACE": namespace, "PYTHONUNBUFFERED": "1"}
    server_proc = subprocess.Popen(
        ["uv", "run", "exo"],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        env=env,
    )
    print(f"exo started (pid={server_proc.pid})", flush=True)

    wait_for_api()
    create_instance()
    wait_for_instance()


def start_mlx_lm():
    global server_proc
    log_file = open("/tmp/mlx_lm_gpu_lock_repro.log", "w", buffering=1)
    print("\nStarting mlx_lm.server on port 8080...", flush=True)
    print(f"  Model path: {MODEL_PATH}", flush=True)
    print("Log: /tmp/mlx_lm_gpu_lock_repro.log", flush=True)
    print("  tail -f /tmp/mlx_lm_gpu_lock_repro.log  (in another terminal to watch)", flush=True)
    env = {**os.environ, "PYTHONUNBUFFERED": "1"}
    server_proc = subprocess.Popen(
        ["uv", "run", "mlx_lm.server", "--model", MODEL_PATH, "--port", "8080"],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        env=env,
        cwd=os.path.expanduser("~/mlx-lm"),
    )
    print(f"mlx_lm.server started (pid={server_proc.pid})", flush=True)
    wait_for_api()

def chat_loop():
    print("\n" + "-" * 60, flush=True)
    print("Starting chat completion loop. Watching for stalls...", flush=True)
    print("-" * 60 + "\n", flush=True)

    timings = []
    request_num = 0

    while True:
        request_num += 1
        print(f"  [#{request_num}] sending...", end="", flush=True)
        req_start = time.time()

        done_event = threading.Event()

        def print_waiting():
            while not done_event.is_set():
                if done_event.wait(5):
                    break
                elapsed_so_far = time.time() - req_start
                print(f" ({elapsed_so_far:.0f}s)", end="", flush=True)

        watcher = threading.Thread(target=print_waiting, daemon=True)
        watcher.start()

        try:
            elapsed, resp = send_chat(request_num)
        except Exception as e:
            print(f" ERROR after {time.time() - req_start:.1f}s: {e}", flush=True)
            time.sleep(2)
            continue
        finally:
            done_event.set()

        timings.append(elapsed)
        content = ""
        try:
            content = resp["choices"][0]["message"]["content"][:80]
        except (KeyError, IndexError):
            content = "<no content>"

        print(f" {elapsed:.2f}s | {content}", flush=True)

        if elapsed > STALL_THRESHOLD_S:
            print("\n", flush=True)
            print("!" * 60, flush=True)
            print("!" * 60, flush=True)
            print("!!!", flush=True)
            print(f"!!! GPU LOCK DETECTED on request #{request_num}", flush=True)
            print(f"!!! Elapsed: {elapsed:.2f}s (threshold: {STALL_THRESHOLD_S}s)", flush=True)
            print("!!!", flush=True)
            print("!" * 60, flush=True)
            print("!" * 60, flush=True)
            print(f"\nTotal requests sent: {request_num}", flush=True)
            print(f"Average time (all): {sum(timings) / len(timings):.2f}s", flush=True)
            normal = [t for t in timings if t <= STALL_THRESHOLD_S]
            if normal:
                print(f"Average time (normal): {sum(normal) / len(normal):.2f}s", flush=True)
            print(f"Max time: {max(timings):.2f}s", flush=True)
            print(f"Min time: {min(timings):.2f}s", flush=True)
            print("\nAll timings:", flush=True)
            for i, t in enumerate(timings, 1):
                marker = "  <<<< STALL" if t > STALL_THRESHOLD_S else ""
                print(f"  #{i}: {t:.2f}s{marker}", flush=True)
            print(f"\nServer still running (pid={server_proc.pid}). Continuing... Ctrl+C to stop.", flush=True)
            print("-" * 60 + "\n", flush=True)

def main():
    global base_url

    parser = argparse.ArgumentParser(description="Reproduce GPU lock issue")
    parser.add_argument("--mlx-lm", action="store_true", help="Use mlx_lm.server instead of exo")
    parser.add_argument("--port", type=int, default=None, help="Override server port")
    args = parser.parse_args()

    mode = "mlx_lm" if args.mlx_lm else "exo"
    port = args.port or (8080 if args.mlx_lm else 52415)
    base_url = f"http://localhost:{port}"

    print("=" * 60, flush=True)
    print("  GPU Lock Reproduction Script", flush=True)
    print(f"  Mode: {mode}", flush=True)
    print(f"  Model: {MODEL_ID}", flush=True)
    print(f"  API: {base_url}", flush=True)
    print(f"  Stall threshold: {STALL_THRESHOLD_S}s", flush=True)
    print("=" * 60, flush=True)

    if args.mlx_lm:
        start_mlx_lm()
    else:
        start_exo()

    chat_loop()


if __name__ == "__main__":
    main()
uv.lock (generated, 10 lines changed)
@@ -378,7 +378,7 @@ dependencies = [
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
-    { name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
+    { name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
     { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1025,7 +1025,7 @@ dependencies = [
     { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
-    { name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
+    { name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
     { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1072,8 +1072,8 @@ cuda13 = [

 [[package]]
 name = "mlx"
-version = "0.30.7.dev20260220+13998a05"
-source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }
+version = "0.30.7.dev20260224+5289547a"
+source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }
 resolution-markers = [
     "sys_platform == 'darwin'",
 ]
@@ -1108,7 +1108,7 @@ version = "0.30.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
+    { name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
     { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },