From a7adfa4712233173e02ee7dded41ddc4e23335b2 Mon Sep 17 00:00:00 2001 From: bojiang Date: Wed, 3 Jul 2024 17:32:21 +0800 Subject: [PATCH] fix: venv --- openllm_next/common.py | 2 +- openllm_next/venv.py | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/openllm_next/common.py b/openllm_next/common.py index 6d832bdf..ad7a5b59 100644 --- a/openllm_next/common.py +++ b/openllm_next/common.py @@ -246,7 +246,7 @@ class VenvSpec(SimpleNamespace): def __hash__(self): return md5( # self.python_version, - *sorted(self.python_packages.values()), + *sorted(self.python_packages), ) diff --git a/openllm_next/venv.py b/openllm_next/venv.py index 4eb477bc..0e7e5124 100644 --- a/openllm_next/venv.py +++ b/openllm_next/venv.py @@ -17,15 +17,22 @@ from openllm_next.common import ( ) -def _resolve_packages(requirement: typing.Union[pathlib.Path, str]) -> dict[str, str]: +@functools.lru_cache +def _resolve_packages(requirement: typing.Union[pathlib.Path, str]): from pip_requirements_parser import RequirementsFile requirements_txt = RequirementsFile.from_file( str(requirement), include_nested=True, ) - deps: dict[str, str] = {} - for req in requirements_txt.requirements: + return requirements_txt.requirements + + +def _filter_preheat_packages(requirements: Iterable) -> list[str]: + PREHEAT_PIP_PACKAGES = ["torch", "vllm"] + + deps: list[str] = [] + for req in requirements: if ( req.is_editable or req.is_local_path @@ -36,9 +43,9 @@ def _resolve_packages(requirement: typing.Union[pathlib.Path, str]) -> dict[str, ): continue for sp in req.specifier: - if sp.operator == "==": + if sp.operator == "==" and req.name in PREHEAT_PIP_PACKAGES: assert req.line is not None - deps[req.name] = req.line + deps.append(req.line) break return deps @@ -52,11 +59,8 @@ def _resolve_bento_env_specs(bento: BentoInfo): if not lock_file.exists(): lock_file = bento.path / "env" / "python" / "requirements.txt" - python_packages = _resolve_packages(lock_file) - PREHEAT_PIP_PACKAGES = ["torch", "vllm"] - preheat_packages = { - k: v for k, v in python_packages.items() if k in PREHEAT_PIP_PACKAGES - } + reqs = _resolve_packages(lock_file) + preheat_packages = _filter_preheat_packages(reqs) ver = ver_file.read_text().strip() return ( VenvSpec( @@ -66,7 +70,7 @@ def _resolve_bento_env_specs(bento: BentoInfo): ), VenvSpec( python_version=ver, - python_packages=python_packages, + python_packages=[v.line for v in reqs], name_prefix=f"{bento.tag.replace(':', '_')}-2-", ), ) @@ -105,7 +109,7 @@ def _ensure_venv( with open(lib_dir / f"{parrent_venv.name}.pth", "w+") as f: f.write(str(parent_lib_dir)) with open(venv / "requirements.txt", "w") as f: - f.write("\n".join(sorted(env_spec.python_packages.values()))) + f.write("\n".join(sorted(env_spec.python_packages))) run_command( [ "python",