diff --git a/tests/models/flan_t5/test_modeling_flan_t5.py b/tests/models/flan_t5/test_modeling_flan_t5.py index b10536c7..5e6ba33c 100644 --- a/tests/models/flan_t5/test_modeling_flan_t5.py +++ b/tests/models/flan_t5/test_modeling_flan_t5.py @@ -32,7 +32,7 @@ def flan_t5_id() -> str: def test_small_flan(qa_prompt: str, flan_t5_id: str): - llm = openllm.AutoLLM.for_model("flan-t5", model_id=flan_t5_id) + llm = openllm.AutoLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True) generate = llm(qa_prompt) assert generate diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 3320203d..8a70e805 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -30,7 +30,7 @@ def opt_id() -> str: def test_small_opt(qa_prompt: str, opt_id: str): - llm = openllm.AutoLLM.for_model("opt", model_id=opt_id) + llm = openllm.AutoLLM.for_model("opt", model_id=opt_id, ensure_available=True) generate = llm(qa_prompt) assert generate