diff --git a/worker/tests/test_multimodel/test_inference_llama70B.py b/worker/tests/test_multimodel/test_inference_llama70B.py index 560faa47..71a67df5 100644 --- a/worker/tests/test_multimodel/test_inference_llama70B.py +++ b/worker/tests/test_multimodel/test_inference_llama70B.py @@ -1,7 +1,7 @@ import asyncio +import os from logging import Logger from typing import Callable -import os import pytest @@ -151,7 +151,13 @@ async def test_2_runner_inference( - +@pytest.mark.skipif( + not ( + os.path.exists(os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/")) + and _get_model_size_gb(os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/")) > 30 + ), + reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded" +) async def test_parallel_inference( logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],