From 72337410cfebcd5f348dd4d5d1ac62e93966bf08 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Wed, 2 Aug 2023 13:10:50 -0400 Subject: [PATCH] fix: nightly resolver for correct tag (#177) --- src/openllm/bundle/__init__.py | 3 ++- src/openllm/bundle/oci/__init__.py | 24 +++++++++++++++++------- src/openllm/cli/ext/get_prompt.py | 3 +-- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/openllm/bundle/__init__.py b/src/openllm/bundle/__init__.py index 2b5d7fc1..70350bc6 100644 --- a/src/openllm/bundle/__init__.py +++ b/src/openllm/bundle/__init__.py @@ -22,7 +22,7 @@ import typing as t from . import oci as oci from ..utils import LazyModule -_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__,} +_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__} if t.TYPE_CHECKING: from . import _package as _package @@ -31,6 +31,7 @@ if t.TYPE_CHECKING: from ._package import construct_python_options as construct_python_options from ._package import create_bento as create_bento from .oci import CONTAINER_NAMES as CONTAINER_NAMES + from .oci import RefResolver as RefResolver from .oci import build_container as build_container from .oci import get_base_container_name as get_base_container_name from .oci import get_base_container_tag as get_base_container_tag diff --git a/src/openllm/bundle/oci/__init__.py b/src/openllm/bundle/oci/__init__.py index 06250057..789218a1 100644 --- a/src/openllm/bundle/oci/__init__.py +++ b/src/openllm/bundle/oci/__init__.py @@ -19,6 +19,7 @@ import logging import pathlib import shutil import subprocess +import tempfile import typing as t import attr @@ -75,7 +76,7 @@ class VersionNotSupported(OpenLLMException): _RefTuple: type[RefTuple] = make_attr_tuple_class("_RefTuple", ["git_hash", "version", "strategy"]) @attr.attrs(eq=False, order=False, slots=True, frozen=True) -class Ref: +class RefResolver: """TODO: Support offline mode. Maybe we need to save git hash when building the Bento. @@ -86,7 +87,15 @@ class Ref: _git: git.cmd.Git = git.cmd.Git(_URI) # TODO: support offline mode @classmethod - def _nightly_ref(cls) -> RefTuple: return _RefTuple((*cls._git.ls_remote(_URI, "main", heads=True).split(), "nightly")) + @functools.lru_cache + def nightly_resolver(cls) -> str: + # Will do a clone bare to tempdir, and return the latest commit hash that we build the base image + # NOTE: this is a bit expensive, but it is ok since we only run this during build + with tempfile.TemporaryDirectory(prefix="openllm-bare-") as tempdir: + cls._git.clone(_URI, tempdir, bare=True, depth=1) + return next(it.hexsha for it in git.Repo(tempdir).iter_commits("main", max_count=10) if "[skip ci]" not in str(it.summary)) + @classmethod + def _nightly_ref(cls) -> RefTuple: return _RefTuple((cls.nightly_resolver(), "refs/heads/main", "nightly")) @classmethod def _release_ref(cls, version_str: str | None = None) -> RefTuple: _use_base_strategy = version_str is None @@ -101,7 +110,7 @@ class Ref: if VersionInfo.from_version_string(version_str) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'") return _RefTuple((*version, "release" if _use_base_strategy else "custom")) @classmethod - def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> Ref: + def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> RefResolver: if strategy_or_version is None or strategy_or_version == "release": logger.debug("Using default strategy 'release' for resolving base image version.") return cls(*cls._release_ref()) @@ -114,12 +123,13 @@ class Ref: return cls(*cls._release_ref(version_str=strategy_or_version)) @property def tag(self) -> str: + # NOTE: latest tag can also be nightly, but discouraged to use it. For nightly refer to use sha- if self.strategy == "latest": return "latest" elif self.strategy == "nightly": return f"sha-{self.git_hash[:7]}" else: return repr(self.version) @functools.lru_cache(maxsize=256) -def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return Ref.from_strategy(strategy).tag +def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return RefResolver.from_strategy(strategy).tag def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]: """This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed. @@ -150,10 +160,10 @@ def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralCon if t.TYPE_CHECKING: CONTAINER_NAMES: dict[LiteralContainerRegistry, str] supported_registries: list[str] -__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries"] -def __dir__() -> list[str]: - return sorted(__all__) +__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"] + +def __dir__() -> list[str]: return sorted(__all__) def __getattr__(name: str) -> t.Any: if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))() diff --git a/src/openllm/cli/ext/get_prompt.py b/src/openllm/cli/ext/get_prompt.py index 8ad89a17..035931ac 100644 --- a/src/openllm/cli/ext/get_prompt.py +++ b/src/openllm/cli/ext/get_prompt.py @@ -25,8 +25,7 @@ import openllm from .. import termui from ..._prompt import process_prompt -if t.TYPE_CHECKING: - from ..entrypoint import LiteralOutput +LiteralOutput = t.Literal["json", "pretty", "porcelain"] @click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS) @click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))