fix: nightly resolver for correct tag (#177)

This commit is contained in:
Aaron Pham
2023-08-02 13:10:50 -04:00
committed by GitHub
parent d4fbfa5e5c
commit 72337410cf
3 changed files with 20 additions and 10 deletions

View File

@@ -22,7 +22,7 @@ import typing as t
from . import oci as oci
from ..utils import LazyModule
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__,}
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__}
if t.TYPE_CHECKING:
from . import _package as _package
@@ -31,6 +31,7 @@ if t.TYPE_CHECKING:
from ._package import construct_python_options as construct_python_options
from ._package import create_bento as create_bento
from .oci import CONTAINER_NAMES as CONTAINER_NAMES
from .oci import RefResolver as RefResolver
from .oci import build_container as build_container
from .oci import get_base_container_name as get_base_container_name
from .oci import get_base_container_tag as get_base_container_tag

View File

@@ -19,6 +19,7 @@ import logging
import pathlib
import shutil
import subprocess
import tempfile
import typing as t
import attr
@@ -75,7 +76,7 @@ class VersionNotSupported(OpenLLMException):
_RefTuple: type[RefTuple] = make_attr_tuple_class("_RefTuple", ["git_hash", "version", "strategy"])
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
class Ref:
class RefResolver:
"""TODO: Support offline mode.
Maybe we need to save git hash when building the Bento.
@@ -86,7 +87,15 @@ class Ref:
_git: git.cmd.Git = git.cmd.Git(_URI) # TODO: support offline mode
@classmethod
def _nightly_ref(cls) -> RefTuple: return _RefTuple((*cls._git.ls_remote(_URI, "main", heads=True).split(), "nightly"))
@functools.lru_cache
def nightly_resolver(cls) -> str:
# Will do a clone bare to tempdir, and return the latest commit hash that we build the base image
# NOTE: this is a bit expensive, but it is ok since we only run this during build
with tempfile.TemporaryDirectory(prefix="openllm-bare-") as tempdir:
cls._git.clone(_URI, tempdir, bare=True, depth=1)
return next(it.hexsha for it in git.Repo(tempdir).iter_commits("main", max_count=10) if "[skip ci]" not in str(it.summary))
@classmethod
def _nightly_ref(cls) -> RefTuple: return _RefTuple((cls.nightly_resolver(), "refs/heads/main", "nightly"))
@classmethod
def _release_ref(cls, version_str: str | None = None) -> RefTuple:
_use_base_strategy = version_str is None
@@ -101,7 +110,7 @@ class Ref:
if VersionInfo.from_version_string(version_str) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
return _RefTuple((*version, "release" if _use_base_strategy else "custom"))
@classmethod
def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> Ref:
def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> RefResolver:
if strategy_or_version is None or strategy_or_version == "release":
logger.debug("Using default strategy 'release' for resolving base image version.")
return cls(*cls._release_ref())
@@ -114,12 +123,13 @@ class Ref:
return cls(*cls._release_ref(version_str=strategy_or_version))
@property
def tag(self) -> str:
# NOTE: latest tag can also be nightly, but discouraged to use it. For nightly refer to use sha-<git_hash_short>
if self.strategy == "latest": return "latest"
elif self.strategy == "nightly": return f"sha-{self.git_hash[:7]}"
else: return repr(self.version)
@functools.lru_cache(maxsize=256)
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return Ref.from_strategy(strategy).tag
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return RefResolver.from_strategy(strategy).tag
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
@@ -150,10 +160,10 @@ def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralCon
if t.TYPE_CHECKING:
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
supported_registries: list[str]
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries"]
def __dir__() -> list[str]:
return sorted(__all__)
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]
def __dir__() -> list[str]: return sorted(__all__)
def __getattr__(name: str) -> t.Any:
if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()

View File

@@ -25,8 +25,7 @@ import openllm
from .. import termui
from ..._prompt import process_prompt
if t.TYPE_CHECKING:
from ..entrypoint import LiteralOutput
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
@click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS)
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))