diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..c6a5375df --- /dev/null +++ b/conftest.py @@ -0,0 +1 @@ +collect_ignore = ["tests/start_distributed_test.py"] diff --git a/e2e/conftest.py b/e2e/conftest.py index 3b6bd5b6b..aa11e827c 100644 --- a/e2e/conftest.py +++ b/e2e/conftest.py @@ -5,8 +5,8 @@ import json import os import sys from pathlib import Path -from urllib.request import urlopen, Request from urllib.error import URLError +from urllib.request import Request, urlopen E2E_DIR = Path(__file__).parent.resolve() TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120")) @@ -22,8 +22,10 @@ class Cluster: for path in overrides or []: compose_files.append(str(E2E_DIR / path)) self._compose_base = [ - "docker", "compose", - "-p", self.project, + "docker", + "compose", + "-p", + self.project, *[arg for f in compose_files for arg in ("-f", f)], ] @@ -35,7 +37,8 @@ class Cluster: async def _run(self, *args: str, check: bool = True) -> str: proc = await asyncio.create_subprocess_exec( - *self._compose_base, *args, + *self._compose_base, + *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, ) @@ -43,7 +46,9 @@ class Cluster: output = stdout.decode() if check and proc.returncode != 0: print(output, file=sys.stderr) - raise RuntimeError(f"docker compose {' '.join(args)} failed (rc={proc.returncode})") + raise RuntimeError( + f"docker compose {' '.join(args)} failed (rc={proc.returncode})" + ) return output async def build(self): @@ -61,17 +66,25 @@ class Cluster: async def logs(self) -> str: return await self._run("logs", check=False) - async def exec(self, service: str, *cmd: str, check: bool = True) -> tuple[int, str]: + async def exec( + self, service: str, *cmd: str, check: bool = True + ) -> tuple[int, str]: """Run a command inside a running container. Returns (returncode, output).""" proc = await asyncio.create_subprocess_exec( - *self._compose_base, "exec", "-T", service, *cmd, + *self._compose_base, + "exec", + "-T", + service, + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, ) stdout, _ = await proc.communicate() output = stdout.decode() if check and proc.returncode != 0: - raise RuntimeError(f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})") + raise RuntimeError( + f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})" + ) return proc.returncode, output async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT): @@ -114,13 +127,19 @@ class Cluster: await self.wait_for("Master election resolved", master_elected) await self.wait_for("API responding", api_responding) - async def _api(self, method: str, path: str, body: dict | None = None, timeout: int = 30) -> dict: + async def _api( + self, method: str, path: str, body: dict | None = None, timeout: int = 30 + ) -> dict: """Make an API request to the cluster. Returns parsed JSON.""" url = f"http://localhost:52415{path}" data = json.dumps(body).encode() if body else None - req = Request(url, data=data, headers={"Content-Type": "application/json"}, method=method) + req = Request( + url, data=data, headers={"Content-Type": "application/json"}, method=method + ) loop = asyncio.get_event_loop() - resp_bytes = await loop.run_in_executor(None, lambda: urlopen(req, timeout=timeout).read()) + resp_bytes = await loop.run_in_executor( + None, lambda: urlopen(req, timeout=timeout).read() + ) return json.loads(resp_bytes) async def place_model(self, model: str, timeout: int = 600): @@ -136,7 +155,9 @@ class Cluster: await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout) - async def chat(self, model: str, messages: list[dict], timeout: int = 600, **kwargs) -> dict: + async def chat( + self, model: str, messages: list[dict], timeout: int = 600, **kwargs + ) -> dict: """Send a chat completion request. Retries until model is downloaded and inference completes.""" body = json.dumps({"model": model, "messages": messages, **kwargs}).encode() deadline = asyncio.get_event_loop().time() + timeout @@ -150,7 +171,9 @@ class Cluster: headers={"Content-Type": "application/json"}, ) loop = asyncio.get_event_loop() - resp_bytes = await loop.run_in_executor(None, lambda: urlopen(req, timeout=300).read()) + resp_bytes = await loop.run_in_executor( + None, lambda r=req: urlopen(r, timeout=300).read() + ) return json.loads(resp_bytes) except Exception as e: last_error = e diff --git a/e2e/run_all.py b/e2e/run_all.py index abff359d3..0ab2aafac 100644 --- a/e2e/run_all.py +++ b/e2e/run_all.py @@ -22,7 +22,9 @@ def is_slow(test_file: Path) -> bool: if line.strip().startswith('"""') or line.strip().startswith("'''"): # Read into the docstring for doc_line in f: - if "slow" in doc_line.lower() and doc_line.strip().startswith("slow"): + if "slow" in doc_line.lower() and doc_line.strip().startswith( + "slow" + ): return True if '"""' in doc_line or "'''" in doc_line: break @@ -60,7 +62,9 @@ def main(): total = passed + failed + skipped print("================================") - print(f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")) + print( + f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "") + ) if failed: print(f"Failed: {' '.join(failures)}") diff --git a/e2e/test_cluster_formation.py b/e2e/test_cluster_formation.py index 444830d5b..3bb528a7a 100644 --- a/e2e/test_cluster_formation.py +++ b/e2e/test_cluster_formation.py @@ -5,6 +5,7 @@ Verifies two nodes discover each other, elect a master, and the API responds. import asyncio import sys + sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent)) from conftest import Cluster diff --git a/e2e/test_inference_snapshot.py b/e2e/test_inference_snapshot.py index 796282109..79d88e8f3 100644 --- a/e2e/test_inference_snapshot.py +++ b/e2e/test_inference_snapshot.py @@ -13,6 +13,7 @@ import asyncio import json import sys from pathlib import Path + sys.path.insert(0, str(Path(__file__).parent)) from conftest import Cluster @@ -55,17 +56,23 @@ async def main(): f" Got: {content!r}\n" f" Delete {SNAPSHOT_FILE} to regenerate." ) - print(f" Output matches snapshot") + print(" Output matches snapshot") else: SNAPSHOT_FILE.parent.mkdir(parents=True, exist_ok=True) - SNAPSHOT_FILE.write_text(json.dumps({ - "model": MODEL, - "seed": SEED, - "temperature": 0, - "prompt": PROMPT, - "max_tokens": MAX_TOKENS, - "content": content, - }, indent=2) + "\n") + SNAPSHOT_FILE.write_text( + json.dumps( + { + "model": MODEL, + "seed": SEED, + "temperature": 0, + "prompt": PROMPT, + "max_tokens": MAX_TOKENS, + "content": content, + }, + indent=2, + ) + + "\n" + ) print(f" Snapshot created: {SNAPSHOT_FILE}") print("PASSED: inference_snapshot") diff --git a/e2e/test_no_internet.py b/e2e/test_no_internet.py index fa03c220d..602518254 100644 --- a/e2e/test_no_internet.py +++ b/e2e/test_no_internet.py @@ -7,6 +7,7 @@ except private subnets and multicast (for mDNS discovery). import asyncio import sys + sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent)) from conftest import Cluster @@ -22,7 +23,15 @@ async def main(): # Verify internet is actually blocked from inside the containers for node in ["exo-node-1", "exo-node-2"]: - rc, _ = await cluster.exec(node, "curl", "-sf", "--max-time", "3", "https://huggingface.co", check=False) + rc, _ = await cluster.exec( + node, + "curl", + "-sf", + "--max-time", + "3", + "https://huggingface.co", + check=False, + ) assert rc != 0, f"{node} should not be able to reach the internet" print(f" {node}: internet correctly blocked") diff --git a/e2e/test_runner_chaos.py b/e2e/test_runner_chaos.py index 0052643d0..40922b1ad 100644 --- a/e2e/test_runner_chaos.py +++ b/e2e/test_runner_chaos.py @@ -13,6 +13,7 @@ import asyncio import contextlib import sys from pathlib import Path + sys.path.insert(0, str(Path(__file__).parent)) from conftest import Cluster @@ -54,8 +55,7 @@ async def main(): # Verify RunnerFailed was emitted (visible in logs) log = await cluster.logs() assert "runner process died unexpectedly" in log, ( - "Expected health check to detect runner death but it didn't.\n" - f"Logs:\n{log}" + f"Expected health check to detect runner death but it didn't.\nLogs:\n{log}" ) print("PASSED: runner_chaos")