"
+
+ def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
+ return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
+
+ if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
+ else: doc = func.__doc__
+ return t.cast(_T, functools.update_wrapper(types.new_class(name, (PartialAny, ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py
index 3fa3c4b7..e685636d 100644
--- a/src/openllm/utils/dantic.py
+++ b/src/openllm/utils/dantic.py
@@ -31,478 +31,432 @@ from click import shell_completion as sc
from click import types as click_types
if t.TYPE_CHECKING:
- from attr import _ValidatorType
+ from attr import _ValidatorType
- from .._types import ListAny
+ from .._types import ListAny
_T = t.TypeVar("_T")
AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
+def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: type[t.Any] | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
+ # TODO: support parsing nested attrs class and Union
+ envvar = field.metadata["env"]
+ dasherized = inflection.dasherize(name)
+ underscored = inflection.underscore(name)
-def attrs_to_options(
- name: str,
- field: attr.Attribute[t.Any],
- model_name: str,
- typ: type[t.Any] | None = None,
- suffix_generation: bool = False,
- suffix_sampling: bool = False,
-) -> t.Callable[[FC], FC]:
- # TODO: support parsing nested attrs class and Union
- envvar = field.metadata["env"]
- dasherized = inflection.dasherize(name)
- underscored = inflection.underscore(name)
+ if typ in (None, attr.NOTHING):
+ typ = field.type
+ if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
- if typ in (None, attr.NOTHING):
- typ = field.type
- if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
-
- full_option_name = f"--{dasherized}"
- if field.type is bool: full_option_name += f"/--no-{dasherized}"
- if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
- elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
- else: identifier = f"{model_name}_{underscored}"
-
- return cog.optgroup.option(
- identifier,
- full_option_name,
- type=parse_type(typ),
- required=field.default is attr.NOTHING,
- default=field.default if field.default not in (attr.NOTHING, None) else None,
- show_default=True,
- multiple=allows_multiple(typ) if typ else False,
- help=field.metadata.get("description", "(No description provided)"),
- show_envvar=True,
- envvar=envvar,
- )
+ full_option_name = f"--{dasherized}"
+ if field.type is bool: full_option_name += f"/--no-{dasherized}"
+ if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
+ elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
+ else: identifier = f"{model_name}_{underscored}"
+ return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,)
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
- if env is not None:
- value = os.environ.get(env, value)
- if value is not None and isinstance(value, str):
- try:
- return orjson.loads(value.lower())
- except orjson.JSONDecodeError as err:
- raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
- return value
+ if env is not None:
+ value = os.environ.get(env, value)
+ if value is not None and isinstance(value, str):
+ try:
+ return orjson.loads(value.lower())
+ except orjson.JSONDecodeError as err:
+ raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
+ return value
+def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any:
+ """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
-def Field(
- default: t.Any = None,
- *,
- ge: int | float | None = None,
- le: int | float | None = None,
- validator: _ValidatorType[_T] | None = None,
- description: str | None = None,
- env: str | None = None,
- auto_default: bool = False,
- use_default_converter: bool = True,
- **attrs: t.Any,
-) -> t.Any:
- """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
+ By default, if both validator and ge are provided, then ge will be
+ piped in first, then all of the other validators will be run afterwards.
- By default, if both validator and ge are provided, then then ge will be
- piped into first, then all of the other validator will be run afterwards.
+ Args:
+ default: The default value for ``dantic.Field``. Defaults to ``None``.
+ ge: Greater than or equal to. Defaults to None.
+ le: Less than or equal to. Defaults to None.
+ validator: Optional attrs-compatible validator. Defaults to None.
+ description: the documentation for the field. Defaults to None.
+ env: the environment variable to read from. Defaults to None.
+ auto_default: a bool indicating whether to use the default value as the environment.
+ Defaults to False. If set to True, the behaviour of this Field will also depend
+ on kw_only. If kw_only=True, then this field will become 'Required' and the default
+ value is omitted. If kw_only=False, then the default value will be used as before.
+ use_default_converter: a bool indicating whether to use the default converter. Defaults
+ to True. If set to False, then the default converter will not be used.
+ The default converter converts a given value from the environment variable
+ for this given Field.
+ **attrs: The rest of the arguments are passed to attr.field
+ """
+ metadata = attrs.pop("metadata", {})
+ if description is None:
+ description = "(No description provided)"
+ metadata["description"] = description
+ if env is not None:
+ metadata["env"] = env
+ piped: list[_ValidatorType[t.Any]] = []
- Args:
- default: The default value for ``dantic.Field``. Defaults to ``None``.
- ge: Greater than or equal to. Defaults to None.
- le: Less than or equal to. Defaults to None.
- validator: Optional attrs-compatible validators type. Default to None
- description: the documentation for the field. Defaults to None.
- env: the environment variable to read from. Defaults to None.
- auto_default: a bool indicating whether to use the default value as the environment.
- Defaults to False. If set to True, the behaviour of this Field will also depends
- on kw_only. If kw_only=True, the this field will become 'Required' and the default
- value is omitted. If kw_only=False, then the default value will be used as before.
- use_default_converter: a bool indicating whether to use the default converter. Defaults
- to True. If set to False, then the default converter will not be used.
- The default converter converts a given value from the environment variable
- for this given Field.
- **attrs: The rest of the arguments are passed to attr.field
- """
- metadata = attrs.pop("metadata", {})
- if description is None:
- description = "(No description provided)"
- metadata["description"] = description
- if env is not None:
- metadata["env"] = env
- piped: list[_ValidatorType[t.Any]] = []
+ converter = attrs.pop("converter", None)
+ if use_default_converter:
+ converter = functools.partial(env_converter, env=env)
- converter = attrs.pop("converter", None)
- if use_default_converter:
- converter = functools.partial(env_converter, env=env)
+ if ge is not None:
+ piped.append(attr.validators.ge(ge))
+ if le is not None:
+ piped.append(attr.validators.le(le))
+ if validator is not None:
+ piped.append(validator)
- if ge is not None:
- piped.append(attr.validators.ge(ge))
- if le is not None:
- piped.append(attr.validators.le(le))
- if validator is not None:
- piped.append(validator)
+ if len(piped) == 0:
+ _validator = None
+ elif len(piped) == 1:
+ _validator = piped[0]
+ else:
+ _validator = attr.validators.and_(*piped)
- if len(piped) == 0:
- _validator = None
- elif len(piped) == 1:
- _validator = piped[0]
- else:
- _validator = attr.validators.and_(*piped)
+ factory = attrs.pop("factory", None)
+ if factory is not None and default is not None:
+ raise RuntimeError("'factory' and 'default' are mutually exclusive.")
+ # NOTE: the behaviour of this is we will respect factory over the default
+ if factory is not None:
+ attrs["factory"] = factory
+ else:
+ attrs["default"] = default
- factory = attrs.pop("factory", None)
- if factory is not None and default is not None:
- raise RuntimeError("'factory' and 'default' are mutually exclusive.")
- # NOTE: the behaviour of this is we will respect factory over the default
- if factory is not None:
- attrs["factory"] = factory
- else:
- attrs["default"] = default
-
- kw_only = attrs.pop("kw_only", False)
- if auto_default and kw_only:
- attrs.pop("default")
-
- return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
+ kw_only = attrs.pop("kw_only", False)
+ if auto_default and kw_only:
+ attrs.pop("default")
+ return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]:
- """Transforms the pydantic field's type into a click-compatible type.
+ """Transforms the pydantic field's type into a click-compatible type.
- Args:
- field_type: pydantic field type
+ Args:
+ field_type: pydantic field type
- Returns:
- ParamType: click type equivalent
- """
- from . import lenient_issubclass
-
- if t.get_origin(field_type) is t.Union:
- raise NotImplementedError("Unions are not supported")
- # enumeration strings or other Enum derivatives
- if lenient_issubclass(field_type, Enum):
- return EnumChoice(enum=field_type, case_sensitive=True)
- # literals are enum-like with way less functionality
- if is_literal(field_type):
- return LiteralChoice(value=field_type, case_sensitive=True)
- # modules, classes, functions
- if is_typing(field_type):
- return ModuleType()
- # entire dictionaries:
- # using a Dict, convert in advance
- if is_mapping(field_type):
- return JsonType()
- # list, List[p], Tuple[p], Set[p] and so on
- if is_container(field_type):
- return parse_container_args(field_type)
- # bytes are not natively supported by click
- if lenient_issubclass(field_type, bytes):
- return BytesType()
- # return the current type: it should be a primitive
- return field_type
+ Returns:
+ ParamType: click type equivalent
+ """
+ from . import lenient_issubclass
+ if t.get_origin(field_type) is t.Union:
+ raise NotImplementedError("Unions are not supported")
+ # enumeration strings or other Enum derivatives
+ if lenient_issubclass(field_type, Enum):
+ return EnumChoice(enum=field_type, case_sensitive=True)
+ # literals are enum-like with way less functionality
+ if is_literal(field_type):
+ return LiteralChoice(value=field_type, case_sensitive=True)
+ # modules, classes, functions
+ if is_typing(field_type):
+ return ModuleType()
+ # entire dictionaries:
+ # using a Dict, convert in advance
+ if is_mapping(field_type):
+ return JsonType()
+ # list, List[p], Tuple[p], Set[p] and so on
+ if is_container(field_type):
+ return parse_container_args(field_type)
+ # bytes are not natively supported by click
+ if lenient_issubclass(field_type, bytes):
+ return BytesType()
+ # return the current type: it should be a primitive
+ return field_type
def is_typing(field_type: type) -> bool:
- """Checks whether the current type is a module-like type.
+ """Checks whether the current type is a module-like type.
- Args:
- field_type: pydantic field type
+ Args:
+ field_type: pydantic field type
- Returns:
- bool: true if the type is itself a type
- """
- raw = t.get_origin(field_type)
- if raw is None:
- return False
- if raw is type or raw is t.Type:
- return True
+ Returns:
+ bool: true if the type is itself a type
+ """
+ raw = t.get_origin(field_type)
+ if raw is None:
return False
-
+ if raw is type or raw is t.Type:
+ return True
+ return False
def is_literal(field_type: type) -> bool:
- """Checks whether the given field type is a Literal type or not.
+ """Checks whether the given field type is a Literal type or not.
- Literals are weird: isinstance and subclass do not work, so you compare
- the origin with the Literal declaration itself.
+ Literals are weird: isinstance and subclass do not work, so you compare
+ the origin with the Literal declaration itself.
- Args:
- field_type: current pydantic type
-
- Returns:
- bool: true if Literal type, false otherwise
- """
- origin = t.get_origin(field_type)
- return origin is not None and origin is t.Literal
+ Args:
+ field_type: current pydantic type
+ Returns:
+ bool: true if Literal type, false otherwise
+ """
+ origin = t.get_origin(field_type)
+ return origin is not None and origin is t.Literal
class ModuleType(ParamType):
- name = "module"
+ name = "module"
- def _import_object(self, value: str) -> t.Any:
- module_name, class_name = value.rsplit(".", maxsplit=1)
- if not all(s.isidentifier() for s in module_name.split(".")):
- raise ValueError(f"'{value}' is not a valid module name")
- if not class_name.isidentifier():
- raise ValueError(f"Variable '{class_name}' is not a valid identifier")
+ def _import_object(self, value: str) -> t.Any:
+ module_name, class_name = value.rsplit(".", maxsplit=1)
+ if not all(s.isidentifier() for s in module_name.split(".")):
+ raise ValueError(f"'{value}' is not a valid module name")
+ if not class_name.isidentifier():
+ raise ValueError(f"Variable '{class_name}' is not a valid identifier")
- module = importlib.import_module(module_name)
- if class_name:
- try:
- return getattr(module, class_name)
- except AttributeError:
- raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
-
- def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
- try:
- if isinstance(value, str):
- return self._import_object(value)
- return value
- except Exception as exc:
- self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
+ module = importlib.import_module(module_name)
+ if class_name:
+ try:
+ return getattr(module, class_name)
+ except AttributeError:
+ raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
+ def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
+ try:
+ if isinstance(value, str):
+ return self._import_object(value)
+ return value
+ except Exception as exc:
+ self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
class EnumChoice(click.Choice):
- name = "enum"
+ name = "enum"
- def __init__(self, enum: Enum, case_sensitive: bool = False):
- """Enum type support for click that extends ``click.Choice``.
+ def __init__(self, enum: Enum, case_sensitive: bool = False):
+ """Enum type support for click that extends ``click.Choice``.
- Args:
- enum: Given enum
- case_sensitive: Whether this choice should be case case_sensitive.
- """
- self.mapping = enum
- self.internal_type = type(enum)
- choices: ListAny = [e.name for e in enum.__class__]
- super().__init__(choices, case_sensitive)
-
- def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
- if isinstance(value, self.internal_type):
- return value
- result = super().convert(value, param, ctx)
- if isinstance(result, str):
- result = self.internal_type[result]
- return result
+ Args:
+ enum: Given enum
+ case_sensitive: Whether this choice should be case sensitive.
+ """
+ self.mapping = enum
+ self.internal_type = type(enum)
+ choices: ListAny = [e.name for e in enum.__class__]
+ super().__init__(choices, case_sensitive)
+ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
+ if isinstance(value, self.internal_type):
+ return value
+ result = super().convert(value, param, ctx)
+ if isinstance(result, str):
+ result = self.internal_type[result]
+ return result
class LiteralChoice(EnumChoice):
- name = "literal"
-
- def __init__(self, value: t.Any, case_sensitive: bool = False):
- """Literal support for click."""
- # expect every literal value to belong to the same primitive type
- values = list(value.__args__)
- item_type = type(values[0])
- if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
- _mapping = {str(v): v for v in values}
- super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
- self.internal_type = item_type
+ name = "literal"
+ def __init__(self, value: t.Any, case_sensitive: bool = False):
+ """Literal support for click."""
+ # expect every literal value to belong to the same primitive type
+ values = list(value.__args__)
+ item_type = type(values[0])
+ if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
+ _mapping = {str(v): v for v in values}
+ super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
+ self.internal_type = item_type
def allows_multiple(field_type: type[t.Any]) -> bool:
- """Checks whether the current type allows for multiple arguments to be provided as input or not.
+ """Checks whether the current type allows for multiple arguments to be provided as input or not.
- For containers, it exploits click's support for lists and such to use the same option multiple times
- to create a complex object: `python run.py --subsets train --subsets test`
- # becomes `subsets: ["train", "test"]`.
+ For containers, it exploits click's support for lists and such to use the same option multiple times
+ to create a complex object: `python run.py --subsets train --subsets test`
+ # becomes `subsets: ["train", "test"]`.
- Args:
- field_type: pydantic type.
+ Args:
+ field_type: pydantic type.
- Returns:
- bool: true if it's a composite field (lists, containers and so on), false otherwise
- """
- # Early out for mappings, since it's better to deal with them using strings.
- if is_mapping(field_type):
- return False
- # Activate multiple option for (simple) container types
- if is_container(field_type):
- args = parse_container_args(field_type)
- # A non-composite type has a single argument, such as 'List[int]'
- # A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
- # For the moment, only non-composite types are allowed.
- return not isinstance(args, tuple)
+ Returns:
+ bool: true if it's a composite field (lists, containers and so on), false otherwise
+ """
+ # Early out for mappings, since it's better to deal with them using strings.
+ if is_mapping(field_type):
return False
-
+ # Activate multiple option for (simple) container types
+ if is_container(field_type):
+ args = parse_container_args(field_type)
+ # A non-composite type has a single argument, such as 'List[int]'
+ # A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
+ # For the moment, only non-composite types are allowed.
+ return not isinstance(args, tuple)
+ return False
def is_mapping(field_type: type) -> bool:
- """Checks whether this field represents a dictionary or JSON object.
+ """Checks whether this field represents a dictionary or JSON object.
- Args:
- field_type (type): pydantic type
-
- Returns:
- bool: true when the field is a dict-like object, false otherwise.
- """
- # Early out for standard containers.
- from . import lenient_issubclass
- if lenient_issubclass(field_type, t.Mapping): return True
- # for everything else or when the typing is more complex, check its origin
- origin = t.get_origin(field_type)
- if origin is None: return False
- return lenient_issubclass(origin, t.Mapping)
+ Args:
+ field_type (type): pydantic type
+ Returns:
+ bool: true when the field is a dict-like object, false otherwise.
+ """
+ # Early out for standard containers.
+ from . import lenient_issubclass
+ if lenient_issubclass(field_type, t.Mapping): return True
+ # for everything else or when the typing is more complex, check its origin
+ origin = t.get_origin(field_type)
+ if origin is None: return False
+ return lenient_issubclass(origin, t.Mapping)
def is_container(field_type: type) -> bool:
- """Checks whether the current type is a container type ('contains' other types), like lists and tuples.
+ """Checks whether the current type is a container type ('contains' other types), like lists and tuples.
- Args:
- field_type: pydantic field type
-
- Returns:
- bool: true if a container, false otherwise
- """
- # do not consider strings or byte arrays as containers
- if field_type in (str, bytes): return False
- # Early out for standard containers: list, tuple, range
- from . import lenient_issubclass
- if lenient_issubclass(field_type, t.Container): return True
- origin = t.get_origin(field_type)
- # Early out for non-typing objects
- if origin is None: return False
- return lenient_issubclass(origin, t.Container)
+ Args:
+ field_type: pydantic field type
+ Returns:
+ bool: true if a container, false otherwise
+ """
+ # do not consider strings or byte arrays as containers
+ if field_type in (str, bytes): return False
+ # Early out for standard containers: list, tuple, range
+ from . import lenient_issubclass
+ if lenient_issubclass(field_type, t.Container): return True
+ origin = t.get_origin(field_type)
+ # Early out for non-typing objects
+ if origin is None: return False
+ return lenient_issubclass(origin, t.Container)
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType]:
- """Parses the arguments inside a container type (lists, tuples and so on).
+ """Parses the arguments inside a container type (lists, tuples and so on).
- Args:
- field_type: pydantic field type
-
- Returns:
- ParamType | tuple[ParamType]: single click-compatible type or a tuple
- """
- if not is_container(field_type):
- raise ValueError("Field type is not a container type.")
- args = t.get_args(field_type)
- # Early out for untyped containers: standard lists, tuples, List[Any]
- # Use strings when the type is unknown, avoid click's type guessing
- if len(args) == 0:
- return click_types.convert_type(str)
- # Early out for homogenous containers: Tuple[int], List[str]
- if len(args) == 1:
- return parse_single_arg(args[0])
- # Early out for homogenous tuples of indefinite length: Tuple[int, ...]
- if len(args) == 2 and args[1] is Ellipsis:
- return parse_single_arg(args[0])
- # Then deal with fixed-length containers: Tuple[str, int, int]
- return tuple(parse_single_arg(arg) for arg in args)
+ Args:
+ field_type: pydantic field type
+ Returns:
+ ParamType | tuple[ParamType]: single click-compatible type or a tuple
+ """
+ if not is_container(field_type):
+ raise ValueError("Field type is not a container type.")
+ args = t.get_args(field_type)
+ # Early out for untyped containers: standard lists, tuples, List[Any]
+ # Use strings when the type is unknown, avoid click's type guessing
+ if len(args) == 0:
+ return click_types.convert_type(str)
+ # Early out for homogenous containers: Tuple[int], List[str]
+ if len(args) == 1:
+ return parse_single_arg(args[0])
+ # Early out for homogenous tuples of indefinite length: Tuple[int, ...]
+ if len(args) == 2 and args[1] is Ellipsis:
+ return parse_single_arg(args[0])
+ # Then deal with fixed-length containers: Tuple[str, int, int]
+ return tuple(parse_single_arg(arg) for arg in args)
def parse_single_arg(arg: type) -> ParamType:
- """Returns the click-compatible type for container origin types.
+ """Returns the click-compatible type for container origin types.
- In this case, returns string when it's not inferrable, a JSON for mappings
- and the original type itself in every other case (ints, floats and so on).
- Bytes is a special case, not natively handled by click.
+ In this case, returns string when it's not inferrable, a JSON for mappings
+ and the original type itself in every other case (ints, floats and so on).
+ Bytes is a special case, not natively handled by click.
- Args:
- arg (type): single argument
-
- Returns:
- ParamType: click-compatible type
- """
- from . import lenient_issubclass
- # When we don't know the type, we choose 'str'
- if arg is t.Any: return click_types.convert_type(str)
- # For containers and nested models, we use JSON
- if is_container(arg): return JsonType()
- if lenient_issubclass(arg, bytes): return BytesType()
- return click_types.convert_type(arg)
+ Args:
+ arg (type): single argument
+ Returns:
+ ParamType: click-compatible type
+ """
+ from . import lenient_issubclass
+ # When we don't know the type, we choose 'str'
+ if arg is t.Any: return click_types.convert_type(str)
+ # For containers and nested models, we use JSON
+ if is_container(arg): return JsonType()
+ if lenient_issubclass(arg, bytes): return BytesType()
+ return click_types.convert_type(arg)
class BytesType(ParamType):
- name = "bytes"
-
- def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
- if isinstance(value, bytes):
- return value
- try:
- return str.encode(value)
- except Exception as exc:
- self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
+ name = "bytes"
+ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
+ if isinstance(value, bytes):
+ return value
+ try:
+ return str.encode(value)
+ except Exception as exc:
+ self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
CYGWIN = sys.platform.startswith("cygwin")
WIN = sys.platform.startswith("win")
if sys.platform.startswith("win") and WIN:
- def _get_argv_encoding() -> str:
- import locale
+ def _get_argv_encoding() -> str:
+ import locale
- return locale.getpreferredencoding()
+ return locale.getpreferredencoding()
else:
- def _get_argv_encoding() -> str:
- return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
-
+ def _get_argv_encoding() -> str:
+ return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
class CudaValueType(ParamType):
- name = "cuda"
- envvar_list_splitter = ","
- is_composite = True
- typ = click_types.convert_type(str)
+ name = "cuda"
+ envvar_list_splitter = ","
+ is_composite = True
+ typ = click_types.convert_type(str)
- def split_envvar_value(self, rv: str) -> t.Sequence[str]:
- var = tuple(i for i in rv.split(self.envvar_list_splitter))
- if "-1" in var:
- return var[: var.index("-1")]
- return var
+ def split_envvar_value(self, rv: str) -> t.Sequence[str]:
+ var = tuple(i for i in rv.split(self.envvar_list_splitter))
+ if "-1" in var:
+ return var[:var.index("-1")]
+ return var
- def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
- """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
+ def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
+ """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
- Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
+ Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
- Args:
- ctx: Invocation context for this command.
- param: The parameter that is requesting completion.
- incomplete: Value being completed. May be empty.
- """
- from ..utils import available_devices
+ Args:
+ ctx: Invocation context for this command.
+ param: The parameter that is requesting completion.
+ incomplete: Value being completed. May be empty.
+ """
+ from ..utils import available_devices
- mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
+ mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
- return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
+ return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
- def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
- if isinstance(value, bytes):
- enc = _get_argv_encoding()
- try:
- value = value.decode(enc)
- except UnicodeError:
- fs_enc = sys.getfilesystemencoding()
- if fs_enc != enc:
- try:
- value = value.decode(fs_enc)
- except UnicodeError:
- value = value.decode("utf-8", "replace")
- else:
- value = value.decode("utf-8", "replace")
+ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
+ if isinstance(value, bytes):
+ enc = _get_argv_encoding()
+ try:
+ value = value.decode(enc)
+ except UnicodeError:
+ fs_enc = sys.getfilesystemencoding()
+ if fs_enc != enc:
+ try:
+ value = value.decode(fs_enc)
+ except UnicodeError:
+ value = value.decode("utf-8", "replace")
+ else:
+ value = value.decode("utf-8", "replace")
- return tuple(self.typ(x, param, ctx) for x in value.split(","))
-
- def __repr__(self) -> str:
- """CUDA is a click.STRING extension."""
- return "STRING"
+ return tuple(self.typ(x, param, ctx) for x in value.split(","))
+ def __repr__(self) -> str:
+ """CUDA is a click.STRING extension."""
+ return "STRING"
CUDA = CudaValueType()
-
class JsonType(ParamType):
- name = "json"
+ name = "json"
- def __init__(self, should_load: bool = True) -> None:
- """Support JSON type for click.ParamType.
+ def __init__(self, should_load: bool = True) -> None:
+ """Support JSON type for click.ParamType.
- Args:
- should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
- """
- super().__init__()
- self.should_load = should_load
+ Args:
+ should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
+ """
+ super().__init__()
+ self.should_load = should_load
- def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
- from . import LazyType
- if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value
- try: return orjson.loads(value)
- except orjson.JSONDecodeError as exc: self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
+ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
+ from . import LazyType
+ if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value
+ try:
+ return orjson.loads(value)
+ except orjson.JSONDecodeError as exc:
+ self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
diff --git a/src/openllm/utils/dummy_flax_objects.py b/src/openllm/utils/dummy_flax_objects.py
index c30717cc..36bb4595 100644
--- a/src/openllm/utils/dummy_flax_objects.py
+++ b/src/openllm/utils/dummy_flax_objects.py
@@ -19,27 +19,24 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
if t.TYPE_CHECKING:
- from ..models.auto.factory import _LazyAutoMapping
+ from ..models.auto.factory import _LazyAutoMapping
class FlaxFlanT5(metaclass=DummyMetaclass):
- _backends = ["flax"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["flax"])
+ _backends = ["flax"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["flax"])
class FlaxOPT(metaclass=DummyMetaclass):
- _backends = ["flax"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["flax"])
+ _backends = ["flax"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["flax"])
class AutoFlaxLLM(metaclass=DummyMetaclass):
- _backends = ["flax"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["flax"])
+ _backends = ["flax"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["flax"])
MODEL_FLAX_MAPPING = t.cast("_LazyAutoMapping", None)
diff --git a/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py b/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py
index 6830e5fb..0f2b4031 100644
--- a/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py
+++ b/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py
@@ -19,14 +19,13 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
class ChatGLM(metaclass=DummyMetaclass):
- _backends = ["torch", "cpm_kernels"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch", "cpm_kernels"])
+ _backends = ["torch", "cpm_kernels"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch", "cpm_kernels"])
class Baichuan(metaclass=DummyMetaclass):
- _backends = ["torch", "cpm_kernels"]
+ _backends = ["torch", "cpm_kernels"]
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch", "cpm_kernels"])
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch", "cpm_kernels"])
diff --git a/src/openllm/utils/dummy_pt_and_einops_objects.py b/src/openllm/utils/dummy_pt_and_einops_objects.py
index e10dd6ff..e3b54dfe 100644
--- a/src/openllm/utils/dummy_pt_and_einops_objects.py
+++ b/src/openllm/utils/dummy_pt_and_einops_objects.py
@@ -19,7 +19,7 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
class Falcon(metaclass=DummyMetaclass):
- _backends = ["torch", "einops"]
+ _backends = ["torch", "einops"]
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch", "einops"])
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch", "einops"])
diff --git a/src/openllm/utils/dummy_pt_and_triton_objects.py b/src/openllm/utils/dummy_pt_and_triton_objects.py
index e0ff894a..451eb77b 100644
--- a/src/openllm/utils/dummy_pt_and_triton_objects.py
+++ b/src/openllm/utils/dummy_pt_and_triton_objects.py
@@ -19,7 +19,7 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
class MPT(metaclass=DummyMetaclass):
- _backends = ["torch", "triton"]
+ _backends = ["torch", "triton"]
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
diff --git a/src/openllm/utils/dummy_pt_objects.py b/src/openllm/utils/dummy_pt_objects.py
index c03d2954..5b518fc8 100644
--- a/src/openllm/utils/dummy_pt_objects.py
+++ b/src/openllm/utils/dummy_pt_objects.py
@@ -18,62 +18,54 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
if t.TYPE_CHECKING:
- from ..models.auto.factory import _LazyAutoMapping
+ from ..models.auto.factory import _LazyAutoMapping
class FlanT5(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class OPT(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class GPTNeoX(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class DollyV2(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class StarCoder(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class StableLM(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class Llama(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
class AutoLLM(metaclass=DummyMetaclass):
- _backends = ["torch"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["torch"])
+ _backends = ["torch"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["torch"])
MODEL_MAPPING = t.cast("_LazyAutoMapping", None)
diff --git a/src/openllm/utils/dummy_tf_objects.py b/src/openllm/utils/dummy_tf_objects.py
index ff7d2acd..ee83a12c 100644
--- a/src/openllm/utils/dummy_tf_objects.py
+++ b/src/openllm/utils/dummy_tf_objects.py
@@ -18,27 +18,24 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
if t.TYPE_CHECKING:
- from ..models.auto.factory import _LazyAutoMapping
+ from ..models.auto.factory import _LazyAutoMapping
class TFFlanT5(metaclass=DummyMetaclass):
- _backends = ["tf"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["tf"])
+ _backends = ["tf"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["tf"])
class TFOPT(metaclass=DummyMetaclass):
- _backends = ["tf"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["tf"])
+ _backends = ["tf"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["tf"])
class AutoTFLLM(metaclass=DummyMetaclass):
- _backends = ["tf"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["tf"])
+ _backends = ["tf"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["tf"])
MODEL_TF_MAPPING = t.cast("_LazyAutoMapping", None)
diff --git a/src/openllm/utils/dummy_vllm_objects.py b/src/openllm/utils/dummy_vllm_objects.py
index 2e1b2832..bb819e33 100644
--- a/src/openllm/utils/dummy_vllm_objects.py
+++ b/src/openllm/utils/dummy_vllm_objects.py
@@ -18,26 +18,24 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
if t.TYPE_CHECKING:
- from ..models.auto.factory import _LazyAutoMapping
+ from ..models.auto.factory import _LazyAutoMapping
class VLLMLlama(metaclass=DummyMetaclass):
- _backends = ["vllm"]
+ _backends = ["vllm"]
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["vllm"])
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["vllm"])
class VLLMOPT(metaclass=DummyMetaclass):
- _backends = ["vllm"]
-
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["vllm"])
+ _backends = ["vllm"]
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["vllm"])
class AutoVLLM(metaclass=DummyMetaclass):
- _backends = ["vllm"]
+ _backends = ["vllm"]
- def __init__(self, *args: t.Any, **attrs: t.Any):
- require_backends(self, ["vllm"])
+ def __init__(self, *args: t.Any, **attrs: t.Any):
+ require_backends(self, ["vllm"])
-
-MODEL_VLLM_MAPPING = t.cast("_LazyAutoMapping", None)
+MODEL_VLLM_MAPPING = t.cast("_LazyAutoMapping", None)
diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py
index ac9b587b..9bf4c18a 100644
--- a/src/openllm/utils/import_utils.py
+++ b/src/openllm/utils/import_utils.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
from __future__ import annotations
import functools
@@ -36,37 +35,27 @@ from .representation import ReprMixin
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
- from typing import overload
+ from typing import overload
else:
- from typing_extensions import overload
+ from typing_extensions import overload
if t.TYPE_CHECKING:
- BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
- from .._types import LiteralRuntime
- from .._types import P
- from .._types import T
+ BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
+ from .._types import LiteralRuntime
+ from .._types import P
+ from .._types import T
- class _AnnotatedLazyLoader(LazyLoader, t.Generic[T]):
- DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[[T], t.LiteralString]
- PROMPT_MAPPING: dict[T, t.LiteralString] | None
+ class _AnnotatedLazyLoader(LazyLoader, t.Generic[T]):
+ DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[[T], t.LiteralString]
+ PROMPT_MAPPING: dict[T, t.LiteralString] | None
else:
- _AnnotatedLazyLoader = LazyLoader
- BackendOrderredDict = OrderedDict
+ _AnnotatedLazyLoader = LazyLoader
+ BackendOrderredDict = OrderedDict
logger = logging.getLogger(__name__)
-OPTIONAL_DEPENDENCIES = {
- "opt",
- "flan-t5",
- "vllm",
- "fine-tune",
- "ggml",
- "agents",
- "openai",
- "playground",
- "gptq",
-}
+OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
@@ -74,14 +63,14 @@ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
-
def _is_package_available(package: str) -> bool:
- _package_available = importlib.util.find_spec(package) is not None
- if _package_available:
- try: importlib.metadata.version(package)
- except importlib.metadata.PackageNotFoundError: _package_available = False
- return _package_available
-
+ _package_available = importlib.util.find_spec(package) is not None
+ if _package_available:
+ try:
+ importlib.metadata.version(package)
+ except importlib.metadata.PackageNotFoundError:
+ _package_available = False
+ return _package_available
_torch_available = importlib.util.find_spec("torch") is not None
_tf_available = importlib.util.find_spec("tensorflow") is not None
@@ -98,98 +87,115 @@ _jupytext_available = _is_package_available("jupytext")
_notebook_available = _is_package_available("notebook")
_autogptq_available = _is_package_available("auto_gptq")
-def is_transformers_supports_kbit() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 30)
-def is_transformers_supports_agent() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
-def is_jupyter_available() -> bool: return _jupyter_available
-def is_jupytext_available() -> bool: return _jupytext_available
-def is_notebook_available() -> bool: return _notebook_available
-def is_triton_available() -> bool: return _triton_available
-def is_datasets_available() -> bool: return _datasets_available
-def is_peft_available() -> bool: return _peft_available
-def is_einops_available() -> bool: return _einops_available
-def is_cpm_kernels_available() -> bool: return _cpm_kernel_available
-def is_bitsandbytes_available() -> bool: return _bitsandbytes_available
-def is_autogptq_available() -> bool: return _autogptq_available
-def is_vllm_available() -> bool: return _vllm_available
+def is_transformers_supports_kbit() -> bool:
+ return pkg.pkg_version_info("transformers")[:2] >= (4, 30)
+
+def is_transformers_supports_agent() -> bool:
+ return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
+
+def is_jupyter_available() -> bool:
+ return _jupyter_available
+
+def is_jupytext_available() -> bool:
+ return _jupytext_available
+
+def is_notebook_available() -> bool:
+ return _notebook_available
+
+def is_triton_available() -> bool:
+ return _triton_available
+
+def is_datasets_available() -> bool:
+ return _datasets_available
+
+def is_peft_available() -> bool:
+ return _peft_available
+
+def is_einops_available() -> bool:
+ return _einops_available
+
+def is_cpm_kernels_available() -> bool:
+ return _cpm_kernel_available
+
+def is_bitsandbytes_available() -> bool:
+ return _bitsandbytes_available
+
+def is_autogptq_available() -> bool:
+ return _autogptq_available
+
+def is_vllm_available() -> bool:
+ return _vllm_available
def is_torch_available() -> bool:
- global _torch_available
- if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
- if _torch_available:
- try: importlib.metadata.version("torch")
- except importlib.metadata.PackageNotFoundError: _torch_available = False
- else:
- logger.info("Disabling PyTorch because USE_TF is set")
+ global _torch_available
+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ if _torch_available:
+ try:
+ importlib.metadata.version("torch")
+ except importlib.metadata.PackageNotFoundError:
_torch_available = False
- return _torch_available
+ else:
+ logger.info("Disabling PyTorch because USE_TF is set")
+ _torch_available = False
+ return _torch_available
def is_tf_available() -> bool:
- global _tf_available
- if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True
- else:
+ global _tf_available
+ if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True
+ else:
+ _tf_version = None
+ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ if _tf_available:
+ candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64",)
_tf_version = None
- if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
- if _tf_available:
- candidates = (
- "tensorflow",
- "tensorflow-cpu",
- "tensorflow-gpu",
- "tf-nightly",
- "tf-nightly-cpu",
- "tf-nightly-gpu",
- "intel-tensorflow",
- "intel-tensorflow-avx512",
- "tensorflow-rocm",
- "tensorflow-macos",
- "tensorflow-aarch64",
- )
- _tf_version = None
- # For the metadata, we have to look for both tensorflow and tensorflow-cpu
- for _pkg in candidates:
- try:
- _tf_version = importlib.metadata.version(_pkg)
- break
- except importlib.metadata.PackageNotFoundError: pass
- _tf_available = _tf_version is not None
- if _tf_available:
- if _tf_version and version.parse(_tf_version) < version.parse("2"):
- logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
- _tf_available = False
- else:
- logger.info("Disabling Tensorflow because USE_TORCH is set")
- _tf_available = False
- return _tf_available
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for _pkg in candidates:
+ try:
+ _tf_version = importlib.metadata.version(_pkg)
+ break
+ except importlib.metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if _tf_version and version.parse(_tf_version) < version.parse("2"):
+ logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
+ _tf_available = False
+ else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+ return _tf_available
def is_flax_available() -> bool:
- global _flax_available
- if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
- if _flax_available:
- try:
- importlib.metadata.version("jax")
- importlib.metadata.version("flax")
- except importlib.metadata.PackageNotFoundError: _flax_available = False
- else: _flax_available = False
- return _flax_available
-
+ global _flax_available
+ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ if _flax_available:
+ try:
+ importlib.metadata.version("jax")
+ importlib.metadata.version("flax")
+ except importlib.metadata.PackageNotFoundError:
+ _flax_available = False
+ else:
+ _flax_available = False
+ return _flax_available
def requires_dependencies(package: str | list[str], *, extra: str | list[str] | None = None) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
- import openllm.utils
+ import openllm.utils
- if isinstance(package, str): package = [package]
- if isinstance(extra, str): extra = [extra]
+ if isinstance(package, str): package = [package]
+ if isinstance(extra, str): extra = [extra]
- def decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
- @functools.wraps(func)
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
- for p in package:
- cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None)
- if not ((cached_check is not None and cached_check()) or _is_package_available(p)): raise ImportError( f"{func.__name__} requires '{p}' to be available locally (Currently missing). Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'")
- return func(*args, **kwargs)
+ def decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
+ @functools.wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
+ for p in package:
+ cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None)
+ if not ((cached_check is not None and cached_check()) or _is_package_available(p)):
+ raise ImportError(f"{func.__name__} requires '{p}' to be available locally (Currently missing). Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'")
+ return func(*args, **kwargs)
- return wrapper
-
- return decorator
+ return wrapper
+ return decorator
VLLM_IMPORT_ERROR_WITH_PYTORCH = """\
{0} requires the vLLM library but it was not found in your environment.
@@ -250,7 +256,6 @@ Checkout the instructions on the installation page: https://www.tensorflow.org/i
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""
-
FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment.
Checkout the instructions on the installation page: https://github.com/google/flax and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
@@ -301,142 +306,117 @@ You can install it with pip: `pip install auto-gptq`. Please note that you may n
your runtime after installation.
"""
-BACKENDS_MAPPING = BackendOrderredDict(
- [
- ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
- ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
- ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
- ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)),
- ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)),
- ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
- ("triton", (is_triton_available, TRITON_IMPORT_ERROR)),
- ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
- ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
- ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
- ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)),
- ]
-)
-
+BACKENDS_MAPPING = BackendOrderredDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
+ ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)),])
class DummyMetaclass(ABCMeta):
- """Metaclass for dummy object.
+ """Metaclass for dummy object.
- It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
- """
+ It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
+ """
- _backends: t.List[str]
-
- def __getattribute__(cls, key: str) -> t.Any:
- if key.startswith("_"): return super().__getattribute__(key)
- require_backends(cls, cls._backends)
+ _backends: t.List[str]
+ def __getattribute__(cls, key: str) -> t.Any:
+ if key.startswith("_"): return super().__getattribute__(key)
+ require_backends(cls, cls._backends)
def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
- if not isinstance(backends, (list, tuple)): backends = list(backends)
- name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__
- # Raise an error for users who might not realize that classes without "TF" are torch-only
- if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
- # Raise the inverse error for PyTorch users trying to load TF classes
- if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
- # Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax
- if "vllm" in backends:
- if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
- if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
- if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
-
- checks = (BACKENDS_MAPPING[backend] for backend in backends)
- failed = [msg.format(name) for available, msg in checks if not available()]
- if failed: raise ImportError("".join(failed))
+ if not isinstance(backends, (list, tuple)): backends = list(backends)
+ name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__
+ # Raise an error for users who might not realize that classes without "TF" are torch-only
+ if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
+ # Raise the inverse error for PyTorch users trying to load TF classes
+ if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
+ # Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax
+ if "vllm" in backends:
+ if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
+ if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
+ if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed: raise ImportError("".join(failed))
class EnvVarMixin(ReprMixin):
- model_name: str
+ model_name: str
- @property
- def __repr_keys__(self) -> set[str]:
- return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}
+ @property
+ def __repr_keys__(self) -> set[str]:
+ return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}
- if t.TYPE_CHECKING:
- config: str
- model_id: str
- quantize: str
- framework: str
- bettertransformer: str
- runtime: t.Literal["ggml", "transformers"]
+ if t.TYPE_CHECKING:
+ config: str
+ model_id: str
+ quantize: str
+ framework: str
+ bettertransformer: str
+ runtime: t.Literal["ggml", "transformers"]
- framework_value: LiteralRuntime
- quantize_value: t.Literal["int8", "int4", "gptq"] | None
- bettertransformer_value: bool | None
- model_id_value: str | None
- runtime_value: t.Literal["ggml", "transformers"]
+ framework_value: LiteralRuntime
+ quantize_value: t.Literal["int8", "int4", "gptq"] | None
+ bettertransformer_value: bool | None
+ model_id_value: str | None
+ runtime_value: t.Literal["ggml", "transformers"]
- # fmt: off
- @overload
- def __getitem__(self, item: t.Literal["config"]) -> str: ...
- @overload
- def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
- @overload
- def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
- @overload
- def __getitem__(self, item: t.Literal["framework"]) -> str: ...
- @overload
- def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
- @overload
- def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
- @overload
- def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
- @overload
- def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ...
- @overload
- def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
- @overload
- def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ...
- @overload
- def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
- # fmt: on
+ # fmt: off
+ @overload
+ def __getitem__(self, item: t.Literal["config"]) -> str: ...
+ @overload
+ def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
+ @overload
+ def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
+ @overload
+ def __getitem__(self, item: t.Literal["framework"]) -> str: ...
+ @overload
+ def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
+ @overload
+ def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
+ @overload
+ def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
+ @overload
+ def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ...
+ @overload
+ def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
+ @overload
+ def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ...
+ @overload
+ def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
+ # fmt: on
- def __getitem__(self, item: str | t.Any) -> t.Any:
- if hasattr(self, item): return getattr(self, item)
- raise KeyError(f"Key {item} not found in {self}")
+ def __getitem__(self, item: str | t.Any) -> t.Any:
+ if hasattr(self, item): return getattr(self, item)
+ raise KeyError(f"Key {item} not found in {self}")
- def __new__(
- cls,
- model_name: str,
- implementation: LiteralRuntime = "pt",
- model_id: str | None = None,
- bettertransformer: bool | None = None,
- quantize: t.LiteralString | None = None,
- runtime: t.Literal["ggml", "transformers"] = "transformers",
- ) -> t.Self:
- from . import codegen
- from .._configuration import field_env_key
- model_name = inflection.underscore(model_name)
+ def __new__(cls, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",) -> t.Self:
+ from . import codegen
+ from .._configuration import field_env_key
+ model_name = inflection.underscore(model_name)
- res = super().__new__(cls)
- res.model_name = model_name
+ res = super().__new__(cls)
+ res.model_name = model_name
- # gen properties env key
- for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(res, att, field_env_key(model_name, att.upper()))
+ # gen properties env key
+ for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}:
+ setattr(res, att, field_env_key(model_name, att.upper()))
- # gen properties env value
- attributes_with_values = {
- "framework": (str, implementation),
- "quantize": (str, quantize),
- "bettertransformer": (bool, bettertransformer),
- "model_id": (str, model_id),
- "runtime": (str, runtime),
- }
- globs: dict[str, t.Any] = {"__bool_vars_value": ENV_VARS_TRUE_VALUES, "__env_get": os.getenv, "self": res}
+ # gen properties env value
+ attributes_with_values = {"framework": (str, implementation), "quantize": (str, quantize), "bettertransformer": (bool, bettertransformer), "model_id": (str, model_id), "runtime": (str, runtime),}
+ globs: dict[str, t.Any] = {"__bool_vars_value": ENV_VARS_TRUE_VALUES, "__env_get": os.getenv, "self": res}
- for attribute, (default_type, default_value) in attributes_with_values.items():
- lines: list[str] = []
- if default_type is bool: lines.append(f"return str(__env_get(self['{attribute}'], str(__env_default)).upper() in __bool_vars_value)")
- else: lines.append(f"return __env_get(self['{attribute}'], __env_default)")
+ for attribute, (default_type, default_value) in attributes_with_values.items():
+ lines: list[str] = []
+ if default_type is bool: lines.append(f"return str(__env_get(self['{attribute}'], str(__env_default)).upper() in __bool_vars_value)")
+ else: lines.append(f"return __env_get(self['{attribute}'], __env_default)")
- setattr(res, f"{attribute}_value", codegen.generate_function(cls, "_env_get_" + attribute, lines, ("__env_default",), globs)(default_value))
+ setattr(res, f"{attribute}_value", codegen.generate_function(cls, "_env_get_" + attribute, lines, ("__env_default",), globs)(default_value))
- return res
- @property
- def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
- @property
- def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
+ return res
+
+ @property
+ def start_docstring(self) -> str:
+ return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
+
+ @property
+ def module(self) -> _AnnotatedLazyLoader[t.LiteralString]:
+ return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
diff --git a/src/openllm/utils/lazy.py b/src/openllm/utils/lazy.py
index bff0a51e..5fc5972f 100644
--- a/src/openllm/utils/lazy.py
+++ b/src/openllm/utils/lazy.py
@@ -30,178 +30,176 @@ from ..exceptions import ForbiddenAttributeError
from ..exceptions import OpenLLMException
class UsageNotAllowedError(OpenLLMException):
- """Raised when LazyModule.__getitem__ is forbidden."""
+ """Raised when LazyModule.__getitem__ is forbidden."""
+
class MissingAttributesError(OpenLLMException):
- """Raised when given keys is not available in LazyModule special mapping."""
+ """Raised when given keys is not available in LazyModule special mapping."""
@functools.total_ordering
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
class VersionInfo:
- """A version object that can be compared to tuple of length 1--4.
+ """A version object that can be compared to tuple of length 1--4.
- ```python
- >>> VersionInfo(19, 1, 0, "final") <= (19, 2)
- True
- >>> VersionInfo(19, 1, 0, "final") < (19, 1, 1)
- True
- >>> vi = VersionInfo(19, 2, 0, "final")
- >>> vi < (19, 1, 1)
- False
- >>> vi < (19,)
- False
- >>> vi == (19, 2,)
- True
- >>> vi == (19, 2, 1)
- False
- ```
- Vendorred from attrs.
+ ```python
+ >>> VersionInfo(19, 1, 0, "final") <= (19, 2)
+ True
+ >>> VersionInfo(19, 1, 0, "final") < (19, 1, 1)
+ True
+ >>> vi = VersionInfo(19, 2, 0, "final")
+ >>> vi < (19, 1, 1)
+ False
+ >>> vi < (19,)
+ False
+ >>> vi == (19, 2,)
+ True
+ >>> vi == (19, 2, 1)
+ False
+ ```
+ Vendored from attrs.
+ """
+ major: int = attr.field()
+ minor: int = attr.field()
+ micro: int = attr.field()
+ releaselevel: str = attr.field()
+
+ @classmethod
+ def from_version_string(cls, s: str) -> VersionInfo:
+ """Parse *s* and return a VersionInfo."""
+ v = s.split(".")
+ if len(v) == 3: v.append("final")
+ return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3])
+
+ def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
+ """Ensure *other* is a tuple of a valid length.
+
+ Returns a possibly transformed *other* and ourselves as a tuple of
+ the same length as *other*.
"""
- major: int = attr.field()
- minor: int = attr.field()
- micro: int = attr.field()
- releaselevel: str = attr.field()
+ cmp = attr.astuple(other) if self.__class__ is other.__class__ else other
+ if not isinstance(cmp, tuple): raise NotImplementedError
+ if not (1 <= len(cmp) <= 4): raise NotImplementedError
+ return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[:len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp)
- @classmethod
- def from_version_string(cls, s: str) -> VersionInfo:
- """Parse *s* and return a VersionInfo."""
- v = s.split(".")
- if len(v) == 3: v.append("final")
- return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3])
- def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
- """Ensure *other* is a tuple of a valid length.
+ def __eq__(self, other: t.Any) -> bool:
+ try:
+ us, them = self._ensure_tuple(other)
+ except NotImplementedError:
+ return NotImplemented
+ return us == them
- Returns a possibly transformed *other* and ourselves as a tuple of
- the same length as *other*.
- """
- cmp = attr.astuple(other) if self.__class__ is other.__class__ else other
- if not isinstance(cmp, tuple): raise NotImplementedError
- if not (1 <= len(cmp) <= 4): raise NotImplementedError
- return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[: len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp)
- def __eq__(self, other: t.Any) -> bool:
- try: us, them = self._ensure_tuple(other)
- except NotImplementedError: return NotImplemented
- return us == them
- def __lt__(self, other: t.Any) -> bool:
- try: us, them = self._ensure_tuple(other)
- except NotImplementedError: return NotImplemented
- # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
- # have to do anything special with releaselevel for now.
- return us < them
+ def __lt__(self, other: t.Any) -> bool:
+ try:
+ us, them = self._ensure_tuple(other)
+ except NotImplementedError:
+ return NotImplemented
+ # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
+ # have to do anything special with releaselevel for now.
+ return us < them
_sentinel, _reserved_namespace = object(), {"__openllm_special__", "__openllm_migration__"}
class LazyModule(types.ModuleType):
- """Module class that surfaces all objects but only performs associated imports when the objects are requested.
+ """Module class that surfaces all objects but only performs associated imports when the objects are requested.
- This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18.
+ This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18.
- This is an extension a more powerful LazyLoader.
+ This is a more powerful extension of LazyLoader.
+ """
+
+ # Very heavily inspired by optuna.integration._IntegrationModule
+ # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
+ def __init__(self, name: str, module_file: str, import_structure: dict[str, list[str]], module_spec: importlib.machinery.ModuleSpec | None = None, doc: str | None = None, extra_objects: dict[str, t.Any] | None = None,):
+ """Lazily load this module as an object.
+
+ It also populates __all__ and provides __dir__ for IDE support.
+
+ Args:
+ name: module name
+ module_file: the given file. Often default to 'globals()['__file__']'
+ import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
+ module_spec: __spec__ of the lazily loaded module
+ doc: Optional docstring for this module.
+ extra_objects: Any additional objects that can also be accessed from this module. Useful for additional metadata as well
+ as any locals() functions
"""
+ super().__init__(name)
+ self._modules = set(import_structure.keys())
+ self._class_to_module: dict[str, str] = {}
+ _extra_objects = {} if extra_objects is None else extra_objects
+ for key, values in import_structure.items():
+ for value in values:
+ self._class_to_module[value] = key
+ # Needed for autocompletion in an IDE
+ self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
+ self.__file__ = module_file
+ self.__spec__ = module_spec
+ self.__path__ = [os.path.dirname(module_file)]
+ self.__doc__ = doc
+ self._objects = _extra_objects
+ self._name = name
+ self._import_structure = import_structure
- # Very heavily inspired by optuna.integration._IntegrationModule
- # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
- def __init__(
- self,
- name: str,
- module_file: str,
- import_structure: dict[str, list[str]],
- module_spec: importlib.machinery.ModuleSpec | None = None,
- doc: str | None = None,
- extra_objects: dict[str, t.Any] | None = None,
- ):
- """Lazily load this module as an object.
+ def __dir__(self) -> list[str]:
+ """Needed for autocompletion in an IDE."""
+ result = t.cast("list[str]", super().__dir__())
+ # The elements of self.__all__ that are submodules may or
+ # may not be in the dir already, depending on whether
+ # they have been accessed or not. So we only add the
+ # elements of self.__all__ that are not already in the dir.
+ return result + [i for i in self.__all__ if i not in result]
- It does instantiate a __all__ and __dir__ for IDE support
+ def __getitem__(self, key: str) -> t.Any:
+ """This is reserved for internal use only; users shouldn't use this."""
+ if self._objects.get("__openllm_special__") is None: raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.")
+ _special_mapping = self._objects.get("__openllm_special__", {})
+ try:
+ if key in _special_mapping: return getattr(self, _special_mapping.__getitem__(key))
+ raise MissingAttributesError(f"Requested '{key}' is not available in given mapping.")
+ except AttributeError as e:
+ raise KeyError(f"'{self._name}' has no attribute {_special_mapping[key]}") from e
+ except Exception as e:
+ raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e
- Args:
- name: module name
- module_file: the given file. Often default to 'globals()['__file__']'
- import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
- module_spec: __spec__ of the lazily loaded module
- doc: Optional docstring for this module.
- extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well
- as any locals() functions
- """
- super().__init__(name)
- self._modules = set(import_structure.keys())
- self._class_to_module: dict[str, str] = {}
- _extra_objects = {} if extra_objects is None else extra_objects
- for key, values in import_structure.items():
- for value in values:
- self._class_to_module[value] = key
- # Needed for autocompletion in an IDE
- self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
- self.__file__ = module_file
- self.__spec__ = module_spec
- self.__path__ = [os.path.dirname(module_file)]
- self.__doc__ = doc
- self._objects = _extra_objects
- self._name = name
- self._import_structure = import_structure
- def __dir__(self) -> list[str]:
- """Needed for autocompletion in an IDE."""
- result = t.cast("list[str]", super().__dir__())
- # The elements of self.__all__ that are submodules may or
- # may not be in the dir already, depending on whether
- # they have been accessed or not. So we only add the
- # elements of self.__all__ that are not already in the dir.
- return result + [i for i in self.__all__ if i not in result]
- def __getitem__(self, key: str) -> t.Any:
- """This is reserved to only internal uses and users shouldn't use this."""
- if self._objects.get("__openllm_special__") is None: raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.")
- _special_mapping = self._objects.get("__openllm_special__", {})
- try:
- if key in _special_mapping: return getattr(self, _special_mapping.__getitem__(key))
- raise MissingAttributesError(f"Requested '{key}' is not available in given mapping.")
- except AttributeError as e: raise KeyError(f"'{self._name}' has no attribute {_special_mapping[key]}") from e
- except Exception as e: raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e
- def __getattr__(self, name: str) -> t.Any:
- """Equivocal __getattr__ implementation.
+ def __getattr__(self, name: str) -> t.Any:
+ """Equivocal __getattr__ implementation.
- It checks from _objects > _modules and does it recursively.
+ It checks from _objects > _modules and does it recursively.
- It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
- """
- if name in _reserved_namespace: raise ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
- dunder_to_metadata = {
- "__title__": "Name",
- "__copyright__": "",
- "__version__": "version",
- "__version_info__": "version",
- "__description__": "summary",
- "__uri__": "",
- "__url__": "",
- "__author__": "",
- "__email__": "",
- "__license__": "license",
- "__homepage__": "",
- }
- if name in dunder_to_metadata:
- if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
- meta = importlib.metadata.metadata("openllm")
- project_url = dict(url.split(", ") for url in meta.get_all("Project-URL"))
- if name == "__license__": return "Apache-2.0"
- elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
- elif name in ("__uri__", "__url__"): return project_url["GitHub"]
- elif name == "__homepage__": return project_url["Homepage"]
- elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__
- elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0]
- elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1]
- return meta[dunder_to_metadata[name]]
- if "__openllm_migration__" in self._objects:
- cur_value = self._objects["__openllm_migration__"].get(name, _sentinel)
- if cur_value is not _sentinel:
- warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
- return getattr(self, cur_value)
- if name in self._objects: return self._objects.__getitem__(name)
- if name in self._modules: value = self._get_module(name)
- elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
- else: raise AttributeError(f"module {self.__name__} has no attribute {name}")
- setattr(self, name, value)
- return value
- def _get_module(self, module_name: str) -> types.ModuleType:
- try: return importlib.import_module("." + module_name, self.__name__)
- except Exception as e: raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e
- def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]:
- """This is to ensure any given module is pickle-able."""
- return (self.__class__, (self._name, self.__file__, self._import_structure))
+ It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
+ """
+ if name in _reserved_namespace: raise ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
+ dunder_to_metadata = {"__title__": "Name", "__copyright__": "", "__version__": "version", "__version_info__": "version", "__description__": "summary", "__uri__": "", "__url__": "", "__author__": "", "__email__": "", "__license__": "license", "__homepage__": "",}
+ if name in dunder_to_metadata:
+ if name not in {"__version_info__", "__copyright__", "__version__"}:
+ warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
+ meta = importlib.metadata.metadata("openllm")
+ project_url = dict(url.split(", ") for url in meta.get_all("Project-URL"))
+ if name == "__license__": return "Apache-2.0"
+ elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
+ elif name in ("__uri__", "__url__"): return project_url["GitHub"]
+ elif name == "__homepage__": return project_url["Homepage"]
+ elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__
+ elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0]
+ elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1]
+ return meta[dunder_to_metadata[name]]
+ if "__openllm_migration__" in self._objects:
+ cur_value = self._objects["__openllm_migration__"].get(name, _sentinel)
+ if cur_value is not _sentinel:
+ warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
+ return getattr(self, cur_value)
+ if name in self._objects: return self._objects.__getitem__(name)
+ if name in self._modules: value = self._get_module(name)
+ elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
+ else: raise AttributeError(f"module {self.__name__} has no attribute {name}")
+ setattr(self, name, value)
+ return value
+
+ def _get_module(self, module_name: str) -> types.ModuleType:
+ try:
+ return importlib.import_module("." + module_name, self.__name__)
+ except Exception as e:
+ raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e
+
+ def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]:
+ """This is to ensure any given module is pickle-able."""
+ return (self.__class__, (self._name, self.__file__, self._import_structure))
diff --git a/src/openllm/utils/representation.py b/src/openllm/utils/representation.py
index 30210853..dfb31810 100644
--- a/src/openllm/utils/representation.py
+++ b/src/openllm/utils/representation.py
@@ -20,53 +20,51 @@ import attr
import orjson
if t.TYPE_CHECKING:
- ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
-
+ ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
class ReprMixin:
- """This class display possible representation of given class.
+ """This class displays possible representations of a given class.
- It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
- Most subclass needs to implement a __repr_keys__ property.
+ It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
+ Most subclasses need to implement a __repr_keys__ property.
- Based on the design from Pydantic.
- The __repr__ will display the json representation of the object for easier interaction.
- The __str__ will display either __attrs_repr__ or __repr_str__.
+ Based on the design from Pydantic.
+ The __repr__ will display the json representation of the object for easier interaction.
+ The __str__ will display either __attrs_repr__ or __repr_str__.
+ """
+ @property
+ @abstractmethod
+ def __repr_keys__(self) -> set[str]:
+ """This can be overridden by base class using this mixin."""
+
+ def __repr__(self) -> str:
+ """The `__repr__` for any subclass of Mixin.
+
+ It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
"""
+ from . import bentoml_cattr
- @property
- @abstractmethod
- def __repr_keys__(self) -> set[str]:
- """This can be overriden by base class using this mixin."""
+ serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
+ return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
- def __repr__(self) -> str:
- """The `__repr__` for any subclass of Mixin.
+ def __str__(self) -> str:
+ """The string representation of the given Mixin subclass.
- It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
- """
- from . import bentoml_cattr
+ It will contain all of the attributes from __repr_keys__
+ """
+ return self.__repr_str__(" ")
- serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
- return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
+ def __repr_name__(self) -> str:
+ """Name of the instance's class, used in __repr__."""
+ return self.__class__.__name__
- def __str__(self) -> str:
- """The string representation of the given Mixin subclass.
+ def __repr_str__(self, join_str: str) -> str:
+ """To be used with __str__."""
+ return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
- It will contains all of the attributes from __repr_keys__
- """
- return self.__repr_str__(" ")
+ def __repr_args__(self) -> ReprArgs:
+ """This can also be overridden by base class using this mixin.
- def __repr_name__(self) -> str:
- """Name of the instance's class, used in __repr__."""
- return self.__class__.__name__
-
- def __repr_str__(self, join_str: str) -> str:
- """To be used with __str__."""
- return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
-
- def __repr_args__(self) -> ReprArgs:
- """This can also be overriden by base class using this mixin.
-
- By default it does a getattr of the current object from __repr_keys__.
- """
- return ((k, getattr(self, k)) for k in self.__repr_keys__)
+ By default it does a getattr of the current object from __repr_keys__.
+ """
+ return ((k, getattr(self, k)) for k in self.__repr_keys__)
diff --git a/src/openllm_client/__init__.py b/src/openllm_client/__init__.py
index 3f1e7f20..35d995c3 100644
--- a/src/openllm_client/__init__.py
+++ b/src/openllm_client/__init__.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""The actual client implementation.
Use ``openllm.client`` instead.
diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py
index dc18eba0..6151c272 100644
--- a/src/openllm_client/runtimes/base.py
+++ b/src/openllm_client/runtimes/base.py
@@ -28,240 +28,316 @@ import openllm
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
- from typing import overload
+ from typing import overload
else:
- from typing_extensions import overload
+ from typing_extensions import overload
if t.TYPE_CHECKING:
- import transformers
- from openllm._types import DictStrAny
- from openllm._types import LiteralRuntime
- class AnnotatedClient(bentoml.client.Client):
- def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ...
- async def async_health(self) -> t.Any: ...
- def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: ...
- def metadata_v1(self) -> dict[str, t.Any]: ...
- def embeddings_v1(self) -> t.Sequence[float]: ...
-else: transformers, DictStrAny = openllm.utils.LazyLoader("transformers", globals(), "transformers"), dict
+ import transformers
+ from openllm._types import DictStrAny
+ from openllm._types import LiteralRuntime
+
+ class AnnotatedClient(bentoml.client.Client):
+ def health(self, *args: t.Any, **attrs: t.Any) -> t.Any:
+ ...
+
+ async def async_health(self) -> t.Any:
+ ...
+
+ def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]:
+ ...
+
+ def metadata_v1(self) -> dict[str, t.Any]:
+ ...
+
+ def embeddings_v1(self) -> t.Sequence[float]:
+ ...
+else:
+
+ transformers, DictStrAny = openllm.utils.LazyLoader("transformers", globals(), "transformers"), dict
logger = logging.getLogger(__name__)
def in_async_context() -> bool:
- try:
- _ = asyncio.get_running_loop()
- return True
- except RuntimeError: return False
+ try:
+ _ = asyncio.get_running_loop()
+ return True
+ except RuntimeError:
+ return False
T = t.TypeVar("T")
class ClientMeta(t.Generic[T]):
- _api_version: str
- _client_class: type[bentoml.client.Client]
- _host: str
- _port: str
+ _api_version: str
+ _client_class: type[bentoml.client.Client]
+ _host: str
+ _port: str
- __client__: AnnotatedClient | None = None
- __agent__: transformers.HfAgent | None = None
- __llm__: openllm.LLM[t.Any, t.Any] | None = None
+ __client__: AnnotatedClient | None = None
+ __agent__: transformers.HfAgent | None = None
+ __llm__: openllm.LLM[t.Any, t.Any] | None = None
- def __init__(self, address: str, timeout: int = 30):
- self._address = address
- self._timeout = timeout
- def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
- """Initialise subclass for HTTP and gRPC client type."""
- cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
- cls._api_version = api_version
- @property
- def _hf_agent(self) -> transformers.HfAgent:
- if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
- if self.__agent__ is None:
- if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
- self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
- return self.__agent__
- @property
- def _metadata(self) -> T:
- if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
- return self.call("metadata")
- @property
- @abstractmethod
- def model_name(self) -> str: raise NotImplementedError
- @property
- @abstractmethod
- def framework(self) -> LiteralRuntime: raise NotImplementedError
- @property
- @abstractmethod
- def timeout(self) -> int: raise NotImplementedError
- @property
- @abstractmethod
- def model_id(self) -> str: raise NotImplementedError
- @property
- @abstractmethod
- def configuration(self) -> dict[str, t.Any]: raise NotImplementedError
- @property
- @abstractmethod
- def supports_embeddings(self) -> bool: raise NotImplementedError
- @property
- @abstractmethod
- def supports_hf_agent(self) -> bool: raise NotImplementedError
- @property
- def llm(self) -> openllm.LLM[t.Any, t.Any]:
- if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
- return self.__llm__
- @property
- def config(self) -> openllm.LLMConfig: return self.llm.config
- def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
- async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
- @property
- def _cached(self) -> AnnotatedClient:
- if self.__client__ is None:
- self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
- self.__client__ = t.cast("AnnotatedClient", self._client_class.from_url(self._address))
- return self.__client__
- @abstractmethod
- def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ...
- @abstractmethod
- def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
+ def __init__(self, address: str, timeout: int = 30):
+ self._address = address
+ self._timeout = timeout
+
+ def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
+ """Initialise subclass for HTTP and gRPC client type."""
+ cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
+ cls._api_version = api_version
+
+ @property
+ def _hf_agent(self) -> transformers.HfAgent:
+ if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
+ if self.__agent__ is None:
+ if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
+ self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
+ return self.__agent__
+
+ @property
+ def _metadata(self) -> T:
+ if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
+ return self.call("metadata")
+
+ @property
+ @abstractmethod
+ def model_name(self) -> str:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def framework(self) -> LiteralRuntime:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def timeout(self) -> int:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def model_id(self) -> str:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def configuration(self) -> dict[str, t.Any]:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def supports_embeddings(self) -> bool:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def supports_hf_agent(self) -> bool:
+ raise NotImplementedError
+
+ @property
+ def llm(self) -> openllm.LLM[t.Any, t.Any]:
+ if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
+ return self.__llm__
+
+ @property
+ def config(self) -> openllm.LLMConfig:
+ return self.llm.config
+
+ def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
+ return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
+
+ async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
+ return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
+
+ @property
+ def _cached(self) -> AnnotatedClient:
+ if self.__client__ is None:
+ self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
+ self.__client__ = t.cast("AnnotatedClient", self._client_class.from_url(self._address))
+ return self.__client__
+
+ @abstractmethod
+ def postprocess(self, result: t.Any) -> openllm.GenerationOutput:
+ ...
+
+ @abstractmethod
+ def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ ...
class BaseClient(ClientMeta[T]):
- def health(self) -> t.Any: raise NotImplementedError
- def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
- def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
- @overload
- def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
- @overload
- def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
- @overload
- def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
- def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
- return_raw_response = attrs.pop("return_raw_response", None)
- if return_raw_response is not None:
- logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
- if return_raw_response is True: return_response = "raw"
- return_attrs = attrs.pop("return_attrs", None)
- if return_attrs is not None:
- logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
- if return_attrs is True: return_response = "attrs"
- use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
- prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
+ def health(self) -> t.Any:
+ raise NotImplementedError
- inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
- if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json()
- else: result = self.call("generate", inputs.model_dump())
- r = self.postprocess(result)
- if return_response == "attrs": return r
- elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r)
- else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
- # NOTE: Scikit interface
- @overload
- def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
- @overload
- def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
- @overload
- def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
- def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
- def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
- if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
- else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
- def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
- if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
- task = kwargs.pop("task", args[0])
- return_code = kwargs.pop("return_code", False)
- remote = kwargs.pop("remote", False)
- try:
- return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
- except Exception as err:
- logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
- logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address)
+ def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
+ raise NotImplementedError
+ def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
+ raise NotImplementedError
+
+ @overload
+ def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
+ ...
+
+ @overload
+ def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
+ ...
+
+ @overload
+ def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
+ ...
+
+ def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
+ return_raw_response = attrs.pop("return_raw_response", None)
+ if return_raw_response is not None:
+ logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
+ if return_raw_response is True: return_response = "raw"
+ return_attrs = attrs.pop("return_attrs", None)
+ if return_attrs is not None:
+ logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
+ if return_attrs is True: return_response = "attrs"
+ use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
+ prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
+
+ inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
+ if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json()
+ else: result = self.call("generate", inputs.model_dump())
+ r = self.postprocess(result)
+ if return_response == "attrs": return r
+ elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r)
+ else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
+
+ # NOTE: Scikit interface
+ @overload
+ def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
+ ...
+
+ @overload
+ def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
+ ...
+
+ @overload
+ def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
+ ...
+
+ def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
+ return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
+
+ def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
+ if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
+ else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
+
+ def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
+ task = kwargs.pop("task", args[0])
+ return_code = kwargs.pop("return_code", False)
+ remote = kwargs.pop("remote", False)
+ try:
+ return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
+ except Exception as err:
+ logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
+ logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address)
class BaseAsyncClient(ClientMeta[T]):
- async def health(self) -> t.Any: raise NotImplementedError
- async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
- async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
- @overload
- async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
- @overload
- async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
- @overload
- async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
- async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
- return_raw_response = attrs.pop("return_raw_response", None)
- if return_raw_response is not None:
- logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
- if return_raw_response is True: return_response = "raw"
- return_attrs = attrs.pop("return_attrs", None)
- if return_attrs is not None:
- logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
- if return_attrs is True: return_response = "attrs"
- use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
- prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
+ async def health(self) -> t.Any:
+ raise NotImplementedError
- inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
- res = await self.acall("generate", inputs.model_dump())
- r = self.postprocess(res)
+ async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
+ raise NotImplementedError
- if return_response == "attrs": return r
- elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r)
- else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
- # NOTE: Scikit interface
- @overload
- async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
- @overload
- async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
- @overload
- async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
- async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
+ async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
+ raise NotImplementedError
- async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
- """Async version of agent.run."""
- if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
- else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
- async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
- if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
- if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
- task = kwargs.pop("task", args[0])
- return_code = kwargs.pop("return_code", False)
- remote = kwargs.pop("remote", False)
+ @overload
+ async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
+ ...
- from transformers.tools.agents import clean_code_for_run
- from transformers.tools.agents import get_tool_creation_code
- from transformers.tools.agents import resolve_tools
- from transformers.tools.python_interpreter import evaluate
+ @overload
+ async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
+ ...
- _hf_agent = self._hf_agent
+ @overload
+ async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
+ ...
- prompt = t.cast(str, _hf_agent.format_prompt(task))
- stop = ["Task:"]
- async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
- response = await client.post(
- _hf_agent.url_endpoint,
- json={
- "inputs": prompt,
- "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
- },
- )
- if response.status_code != HTTPStatus.OK:
- raise ValueError(f"Error {response.status_code}: {response.json()}")
+ async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
+ return_raw_response = attrs.pop("return_raw_response", None)
+ if return_raw_response is not None:
+ logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
+ if return_raw_response is True: return_response = "raw"
+ return_attrs = attrs.pop("return_attrs", None)
+ if return_attrs is not None:
+ logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
+ if return_attrs is True: return_response = "attrs"
+ use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
+ prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
- result = response.json()[0]["generated_text"]
- # Inference API returns the stop sequence
- for stop_seq in stop:
- if result.endswith(stop_seq):
- result = result[: -len(stop_seq)]
- break
+ inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
+ res = await self.acall("generate", inputs.model_dump())
+ r = self.postprocess(res)
- # the below have the same logic as agent.run API
- explanation, code = clean_code_for_run(result)
+ if return_response == "attrs": return r
+ elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r)
+ else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
- _hf_agent.log(f"==Explanation from the agent==\n{explanation}")
+ # NOTE: Scikit interface
+ @overload
+ async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
+ ...
- _hf_agent.log(f"\n\n==Code generated by the agent==\n{code}")
- if not return_code:
- _hf_agent.log("\n\n==Result==")
- _hf_agent.cached_tools = resolve_tools(
- code, _hf_agent.toolbox, remote=remote, cached_tools=_hf_agent.cached_tools
- )
- return evaluate(code, _hf_agent.cached_tools, state=kwargs.copy())
- else:
- tool_code = get_tool_creation_code(code, _hf_agent.toolbox, remote=remote)
- return f"{tool_code}\n{code}"
+ @overload
+ async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
+ ...
+
+ @overload
+ async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
+ ...
+
+ async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
+ return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
+
+ async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
+ """Async version of agent.run."""
+ if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
+ else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
+
+ async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
+ if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
+ task = kwargs.pop("task", args[0])
+ return_code = kwargs.pop("return_code", False)
+ remote = kwargs.pop("remote", False)
+
+ from transformers.tools.agents import clean_code_for_run
+ from transformers.tools.agents import get_tool_creation_code
+ from transformers.tools.agents import resolve_tools
+ from transformers.tools.python_interpreter import evaluate
+
+ _hf_agent = self._hf_agent
+
+ prompt = t.cast(str, _hf_agent.format_prompt(task))
+ stop = ["Task:"]
+ async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
+ response = await client.post(_hf_agent.url_endpoint, json={"inputs": prompt, "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},},)
+ if response.status_code != HTTPStatus.OK:
+ raise ValueError(f"Error {response.status_code}: {response.json()}")
+
+ result = response.json()[0]["generated_text"]
+ # Inference API returns the stop sequence
+ for stop_seq in stop:
+ if result.endswith(stop_seq):
+ result = result[:-len(stop_seq)]
+ break
+
+ # the below have the same logic as agent.run API
+ explanation, code = clean_code_for_run(result)
+
+ _hf_agent.log(f"==Explanation from the agent==\n{explanation}")
+
+ _hf_agent.log(f"\n\n==Code generated by the agent==\n{code}")
+ if not return_code:
+ _hf_agent.log("\n\n==Result==")
+ _hf_agent.cached_tools = resolve_tools(code, _hf_agent.toolbox, remote=remote, cached_tools=_hf_agent.cached_tools)
+ return evaluate(code, _hf_agent.cached_tools, state=kwargs.copy())
+ else:
+ tool_code = get_tool_creation_code(code, _hf_agent.toolbox, remote=remote)
+ return f"{tool_code}\n{code}"
diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py
index a859c5d6..7e913f79 100644
--- a/src/openllm_client/runtimes/grpc.py
+++ b/src/openllm_client/runtimes/grpc.py
@@ -25,96 +25,93 @@ from .base import BaseAsyncClient
from .base import BaseClient
if t.TYPE_CHECKING:
- from grpc_health.v1 import health_pb2
+ from grpc_health.v1 import health_pb2
- from bentoml.grpc.v1.service_pb2 import Response
- from openllm._types import LiteralRuntime
+ from bentoml.grpc.v1.service_pb2 import Response
+ from openllm._types import LiteralRuntime
logger = logging.getLogger(__name__)
-
class GrpcClientMixin:
- if t.TYPE_CHECKING:
-
- @property
- def _metadata(self) -> Response:
- ...
+ if t.TYPE_CHECKING:
@property
- def model_name(self) -> str:
- try:
- return self._metadata.json.struct_value.fields["model_name"].string_value
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ def _metadata(self) -> Response:
+ ...
- @property
- def framework(self) -> LiteralRuntime:
- try:
- value = self._metadata.json.struct_value.fields["framework"].string_value
- if value not in ("pt", "flax", "tf"):
- raise KeyError
- return value
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ @property
+ def model_name(self) -> str:
+ try:
+ return self._metadata.json.struct_value.fields["model_name"].string_value
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def timeout(self) -> int:
- try:
- return int(self._metadata.json.struct_value.fields["timeout"].number_value)
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ @property
+ def framework(self) -> LiteralRuntime:
+ try:
+ value = self._metadata.json.struct_value.fields["framework"].string_value
+ if value not in ("pt", "flax", "tf"):
+ raise KeyError
+ return value
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def model_id(self) -> str:
- try:
- return self._metadata.json.struct_value.fields["model_id"].string_value
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ @property
+ def timeout(self) -> int:
+ try:
+ return int(self._metadata.json.struct_value.fields["timeout"].number_value)
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def configuration(self) -> dict[str, t.Any]:
- try:
- v = self._metadata.json.struct_value.fields["configuration"].string_value
- return orjson.loads(v)
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ @property
+ def model_id(self) -> str:
+ try:
+ return self._metadata.json.struct_value.fields["model_id"].string_value
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def supports_embeddings(self) -> bool:
- try:
- return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ @property
+ def configuration(self) -> dict[str, t.Any]:
+ try:
+ v = self._metadata.json.struct_value.fields["configuration"].string_value
+ return orjson.loads(v)
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def supports_hf_agent(self) -> bool:
- try:
- return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
- except KeyError:
- raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+ @property
+ def supports_embeddings(self) -> bool:
+ try:
+ return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
- if isinstance(result, dict):
- return openllm.GenerationOutput(**result)
+ @property
+ def supports_hf_agent(self) -> bool:
+ try:
+ return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- from google.protobuf.json_format import MessageToDict
+ def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
+ if isinstance(result, dict):
+ return openllm.GenerationOutput(**result)
- return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
+ from google.protobuf.json_format import MessageToDict
+ return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"):
- def __init__(self, address: str, timeout: int = 30):
- self._host, self._port = address.split(":")
- super().__init__(address, timeout)
-
- def health(self) -> health_pb2.HealthCheckResponse:
- return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
+ def __init__(self, address: str, timeout: int = 30):
+ self._host, self._port = address.split(":")
+ super().__init__(address, timeout)
+ def health(self) -> health_pb2.HealthCheckResponse:
+ return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"):
- def __init__(self, address: str, timeout: int = 30):
- self._host, self._port = address.split(":")
- super().__init__(address, timeout)
+ def __init__(self, address: str, timeout: int = 30):
+ self._host, self._port = address.split(":")
+ super().__init__(address, timeout)
- async def health(self) -> health_pb2.HealthCheckResponse:
- return await self._cached.health("bentoml.grpc.v1.BentoService")
+ async def health(self) -> health_pb2.HealthCheckResponse:
+ return await self._cached.health("bentoml.grpc.v1.BentoService")
diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py
index f3e31919..78cf078b 100644
--- a/src/openllm_client/runtimes/http.py
+++ b/src/openllm_client/runtimes/http.py
@@ -28,71 +28,101 @@ from .base import BaseClient
from .base import in_async_context
if t.TYPE_CHECKING:
- from openllm._types import DictStrAny
- from openllm._types import LiteralRuntime
+ from openllm._types import DictStrAny
+ from openllm._types import LiteralRuntime
else:
- DictStrAny = dict
+ DictStrAny = dict
logger = logging.getLogger(__name__)
class HTTPClientMixin:
- if t.TYPE_CHECKING:
- @property
- def _metadata(self) -> DictStrAny: ...
+ if t.TYPE_CHECKING:
+
@property
- def model_name(self) -> str:
- try: return self._metadata["model_name"]
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def model_id(self) -> str:
- try: return self._metadata["model_name"]
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def framework(self) -> LiteralRuntime:
- try: return self._metadata["framework"]
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def timeout(self) -> int:
- try: return self._metadata["timeout"]
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def configuration(self) -> dict[str, t.Any]:
- try: return orjson.loads(self._metadata["configuration"])
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def supports_embeddings(self) -> bool:
- try: return self._metadata.get("supports_embeddings", False)
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- @property
- def supports_hf_agent(self) -> bool:
- try: return self._metadata.get("supports_hf_agent", False)
- except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
- def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
- return openllm.GenerationOutput(**result)
+ def _metadata(self) -> DictStrAny:
+ ...
+
+ @property
+ def model_name(self) -> str:
+ try:
+ return self._metadata["model_name"]
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ @property
+ def model_id(self) -> str:
+ try:
+ return self._metadata["model_name"]
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ @property
+ def framework(self) -> LiteralRuntime:
+ try:
+ return self._metadata["framework"]
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ @property
+ def timeout(self) -> int:
+ try:
+ return self._metadata["timeout"]
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ @property
+ def configuration(self) -> dict[str, t.Any]:
+ try:
+ return orjson.loads(self._metadata["configuration"])
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ @property
+ def supports_embeddings(self) -> bool:
+ try:
+ return self._metadata.get("supports_embeddings", False)
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ @property
+ def supports_hf_agent(self) -> bool:
+ try:
+ return self._metadata.get("supports_hf_agent", False)
+ except KeyError:
+ raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
+
+ def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
+ return openllm.GenerationOutput(**result)
class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
- def __init__(self, address: str, timeout: int = 30):
- address = address if "://" in address else "http://" + address
- self._host, self._port = urlparse(address).netloc.split(":")
- super().__init__(address, timeout)
- def health(self) -> t.Any: return self._cached.health()
- def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
- if not self.supports_embeddings:
- raise ValueError("This model does not support embeddings.")
- if isinstance(prompt, str): prompt = [prompt]
- if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout)
- else: result = self.call("embeddings", list(prompt))
- return openllm.EmbeddingsOutput(**result)
+ def __init__(self, address: str, timeout: int = 30):
+ address = address if "://" in address else "http://" + address
+ self._host, self._port = urlparse(address).netloc.split(":")
+ super().__init__(address, timeout)
+
+ def health(self) -> t.Any:
+ return self._cached.health()
+
+ def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
+ if not self.supports_embeddings:
+ raise ValueError("This model does not support embeddings.")
+ if isinstance(prompt, str): prompt = [prompt]
+ if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout)
+ else: result = self.call("embeddings", list(prompt))
+ return openllm.EmbeddingsOutput(**result)
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
- def __init__(self, address: str, timeout: int = 30):
- address = address if "://" in address else "http://" + address
- self._host, self._port = urlparse(address).netloc.split(":")
- super().__init__(address, timeout)
- async def health(self) -> t.Any: return await self._cached.async_health()
- async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
- if not self.supports_embeddings:
- raise ValueError("This model does not support embeddings.")
- if isinstance(prompt, str): prompt = [prompt]
- res = await self.acall("embeddings", list(prompt))
- return openllm.EmbeddingsOutput(**res)
+ def __init__(self, address: str, timeout: int = 30):
+ address = address if "://" in address else "http://" + address
+ self._host, self._port = urlparse(address).netloc.split(":")
+ super().__init__(address, timeout)
+
+ async def health(self) -> t.Any:
+ return await self._cached.async_health()
+
+ async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
+ if not self.supports_embeddings:
+ raise ValueError("This model does not support embeddings.")
+ if isinstance(prompt, str): prompt = [prompt]
+ res = await self.acall("embeddings", list(prompt))
+ return openllm.EmbeddingsOutput(**res)
diff --git a/tests/__init__.py b/tests/__init__.py
index 5e960cb1..62114cc8 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -18,5 +18,4 @@ from hypothesis import settings
settings.register_profile("CI", settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)
-if "CI" in os.environ:
- settings.load_profile("CI")
+if "CI" in os.environ: settings.load_profile("CI")
diff --git a/tests/_strategies/_configuration.py b/tests/_strategies/_configuration.py
index ab574f69..5cc51585 100644
--- a/tests/_strategies/_configuration.py
+++ b/tests/_strategies/_configuration.py
@@ -25,55 +25,39 @@ logger = logging.getLogger(__name__)
env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_name in openllm.CONFIG_MAPPING.keys()])
-
@st.composite
def model_settings(draw: st.DrawFn):
- """Strategy for generating ModelSettings objects."""
- kwargs: dict[str, t.Any] = {
- "default_id": st.text(min_size=1),
- "model_ids": st.lists(st.text(), min_size=1),
- "architecture": st.text(min_size=1),
- "url": st.text(),
- "requires_gpu": st.booleans(),
- "trust_remote_code": st.booleans(),
- "requirements": st.none() | st.lists(st.text(), min_size=1),
- "default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])),
- "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]),
- "runtime": st.sampled_from(["transformers", "ggml"]),
- "name_type": st.sampled_from(["dasherize", "lowercase"]),
- "timeout": st.integers(min_value=3600),
- "workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
- }
- return draw(st.builds(ModelSettings, **kwargs))
+ """Strategy for generating ModelSettings objects."""
+ kwargs: dict[str, t.Any] = {
+ "default_id": st.text(min_size=1), "model_ids": st.lists(st.text(), min_size=1), "architecture": st.text(min_size=1), "url": st.text(), "requires_gpu": st.booleans(), "trust_remote_code": st.booleans(), "requirements": st.none()
+ | st.lists(st.text(), min_size=1), "default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])), "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]), "runtime": st.sampled_from(["transformers", "ggml"]), "name_type": st.sampled_from(["dasherize", "lowercase"]), "timeout": st.integers(
+ min_value=3600
+ ), "workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
+ }
+ return draw(st.builds(ModelSettings, **kwargs))
+def make_llm_config(cls_name: str, dunder_config: dict[str, t.Any] | ModelSettings, fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,) -> type[openllm.LLMConfig]:
+ globs: dict[str, t.Any] = {"openllm": openllm}
+ _config_args: list[str] = []
+ lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"]
+ for attr, value in dunder_config.items():
+ _config_args.append(f'"{attr}": __attr_{attr}')
+ globs[f"_{cls_name}Config__attr_{attr}"] = value
+ lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
+ if fields is not None:
+ for field, type_, default in fields:
+ lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})")
+ if generation_fields is not None:
+ generation_lines = ["class GenerationConfig:"]
+ for field, default in generation_fields:
+ generation_lines.append(f" {field} = {default!r}")
+ lines.extend((" " + line for line in generation_lines))
-def make_llm_config(
- cls_name: str,
- dunder_config: dict[str, t.Any] | ModelSettings,
- fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None,
- generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,
-) -> type[openllm.LLMConfig]:
- globs: dict[str, t.Any] = {"openllm": openllm}
- _config_args: list[str] = []
- lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"]
- for attr, value in dunder_config.items():
- _config_args.append(f'"{attr}": __attr_{attr}')
- globs[f"_{cls_name}Config__attr_{attr}"] = value
- lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
- if fields is not None:
- for field, type_, default in fields:
- lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})")
- if generation_fields is not None:
- generation_lines = ["class GenerationConfig:"]
- for field, default in generation_fields:
- generation_lines.append(f" {field} = {default!r}")
- lines.extend((" " + line for line in generation_lines))
+ script = "\n".join(lines)
- script = "\n".join(lines)
+ if openllm.utils.DEBUG:
+ logger.info("Generated class %s:\n%s", cls_name, script)
- if openllm.utils.DEBUG:
- logger.info("Generated class %s:\n%s", cls_name, script)
+ eval(compile(script, "name", "exec"), globs)
- eval(compile(script, "name", "exec"), globs)
-
- return globs[f"{cls_name}Config"]
+ return globs[f"{cls_name}Config"]
diff --git a/tests/client_test.py b/tests/client_test.py
index 88650445..df31b722 100644
--- a/tests/client_test.py
+++ b/tests/client_test.py
@@ -17,7 +17,5 @@ from __future__ import annotations
import openllm
def test_import_client():
- assert len(openllm.client.__all__) == 4
- assert all(
- hasattr(openllm.client, attr) for attr in ("AsyncGrpcClient", "GrpcClient", "AsyncHTTPClient", "HTTPClient")
- )
+ assert len(openllm.client.__all__) == 4
+ assert all(hasattr(openllm.client, attr) for attr in ("AsyncGrpcClient", "GrpcClient", "AsyncHTTPClient", "HTTPClient"))
diff --git a/tests/configuration_test.py b/tests/configuration_test.py
index 6a937b9c..5b525009 100644
--- a/tests/configuration_test.py
+++ b/tests/configuration_test.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""All configuration-related tests for openllm.LLMConfig. This will include testing
for ModelEnv construction and parsing environment variables.
"""
@@ -41,213 +40,137 @@ from ._strategies._configuration import model_settings
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
- DictStrAny = dict[str, t.Any]
+ DictStrAny = dict[str, t.Any]
else:
- DictStrAny = dict
-
+ DictStrAny = dict
# XXX: @aarnphm fixes TypedDict behaviour in 3.11
-@pytest.mark.skipif(
- sys.version_info[:2] == (3, 11), reason="TypedDict in 3.11 behaves differently, so we need to fix this"
-)
+@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason="TypedDict in 3.11 behaves differently, so we need to fix this")
def test_missing_default():
- with pytest.raises(ValueError, match="Missing required fields *"):
- make_llm_config("MissingDefaultId", {"name_type": "lowercase", "requirements": ["bentoml"]})
- with pytest.raises(ValueError, match="Missing required fields *"):
- make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]})
- with pytest.raises(ValueError, match="Missing required fields *"):
- make_llm_config(
- "MissingArchitecture",
- {
- "default_id": "huggingface/t5-tiny-testing",
- "model_ids": ["huggingface/t5-tiny-testing"],
- "requirements": ["bentoml"],
- },
- )
-
+ with pytest.raises(ValueError, match="Missing required fields *"):
+ make_llm_config("MissingDefaultId", {"name_type": "lowercase", "requirements": ["bentoml"]})
+ with pytest.raises(ValueError, match="Missing required fields *"):
+ make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]})
+ with pytest.raises(ValueError, match="Missing required fields *"):
+ make_llm_config("MissingArchitecture", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing"], "requirements": ["bentoml"],},)
def test_forbidden_access():
- cl_ = make_llm_config(
- "ForbiddenAccess",
- {
- "default_id": "huggingface/t5-tiny-testing",
- "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"],
- "architecture": "PreTrainedModel",
- "requirements": ["bentoml"],
- },
- )
+ cl_ = make_llm_config("ForbiddenAccess", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"], "architecture": "PreTrainedModel", "requirements": ["bentoml"],},)
- assert pytest.raises(
- openllm.exceptions.ForbiddenAttributeError,
- cl_.__getattribute__,
- cl_(),
- "__config__",
- )
- assert pytest.raises(
- openllm.exceptions.ForbiddenAttributeError,
- cl_.__getattribute__,
- cl_(),
- "GenerationConfig",
- )
- assert pytest.raises(
- openllm.exceptions.ForbiddenAttributeError,
- cl_.__getattribute__,
- cl_(),
- "SamplingParams",
- )
-
- assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)
+ assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "__config__",)
+ assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "GenerationConfig",)
+ assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "SamplingParams",)
+ assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)
@given(model_settings())
def test_class_normal_gen(gen_settings: ModelSettings):
- assume(gen_settings["default_id"] and all(i for i in gen_settings["model_ids"]))
- cl_: type[openllm.LLMConfig] = make_llm_config("NotFullLLM", gen_settings)
- assert issubclass(cl_, openllm.LLMConfig)
- for key in gen_settings:
- assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key)
-
+ assume(gen_settings["default_id"] and all(i for i in gen_settings["model_ids"]))
+ cl_: type[openllm.LLMConfig] = make_llm_config("NotFullLLM", gen_settings)
+ assert issubclass(cl_, openllm.LLMConfig)
+ for key in gen_settings:
+ assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key)
@given(model_settings(), st.integers())
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
- cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
- assert cl_().model_dump()["field1"] == field1
-
+ cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
+ assert cl_().model_dump()["field1"] == field1
@given(model_settings(), st.integers())
def test_config_derivation(gen_settings: ModelSettings, field1: int):
- cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
- new_cls = cl_.model_derivate("DerivedLLM", default_id="asdfasdf")
- assert new_cls.__openllm_default_id__ == "asdfasdf"
-
+ cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
+ new_cls = cl_.model_derivate("DerivedLLM", default_id="asdfasdf")
+ assert new_cls.__openllm_default_id__ == "asdfasdf"
@given(model_settings())
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
- cl_ = make_llm_config("AttrsProtocolLLM", gen_settings)
- assert attr.has(cl_)
+ cl_ = make_llm_config("AttrsProtocolLLM", gen_settings)
+ assert attr.has(cl_)
+@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),)
+def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
+ cl_ = make_llm_config("ComplexLLM", gen_settings, fields=(("field1", "float", field1),), generation_fields=(("temperature", temperature),),)
+ sent = cl_()
+ assert sent.model_dump()["field1"] == field1
+ assert sent.model_dump()["generation_config"]["temperature"] == temperature
+ assert sent.model_dump(flatten=True)["field1"] == field1
+ assert sent.model_dump(flatten=True)["temperature"] == temperature
-@given(
- model_settings(),
- st.integers(max_value=283473),
- st.floats(min_value=0.0, max_value=1.0),
- st.integers(max_value=283473),
- st.floats(min_value=0.0, max_value=1.0),
-)
-def test_complex_struct_dump(
- gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
-):
- cl_ = make_llm_config(
- "ComplexLLM",
- gen_settings,
- fields=(("field1", "float", field1),),
- generation_fields=(("temperature", temperature),),
- )
- sent = cl_()
- assert sent.model_dump()["field1"] == field1
- assert sent.model_dump()["generation_config"]["temperature"] == temperature
- assert sent.model_dump(flatten=True)["field1"] == field1
- assert sent.model_dump(flatten=True)["temperature"] == temperature
-
- passed = cl_(field1=input_field1, temperature=input_temperature)
- assert passed.model_dump()["field1"] == input_field1
- assert passed.model_dump()["generation_config"]["temperature"] == input_temperature
- assert passed.model_dump(flatten=True)["field1"] == input_field1
- assert passed.model_dump(flatten=True)["temperature"] == input_temperature
-
- pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1)
- assert pas_nested.model_dump()["field1"] == input_field1
- assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
+ passed = cl_(field1=input_field1, temperature=input_temperature)
+ assert passed.model_dump()["field1"] == input_field1
+ assert passed.model_dump()["generation_config"]["temperature"] == input_temperature
+ assert passed.model_dump(flatten=True)["field1"] == input_field1
+ assert passed.model_dump(flatten=True)["temperature"] == input_temperature
+ pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1)
+ assert pas_nested.model_dump()["field1"] == input_field1
+ assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
@contextlib.contextmanager
def patch_env(**attrs: t.Any):
- with mock.patch.dict(os.environ, attrs, clear=True):
- yield
-
+ with mock.patch.dict(os.environ, attrs, clear=True):
+ yield
def test_struct_envvar():
- with patch_env(
- **{
- field_env_key("env_llm", "field1"): "4",
- field_env_key("env_llm", "temperature", suffix="generation"): "0.2",
- }
- ):
+ with patch_env(**{field_env_key("env_llm", "field1"): "4", field_env_key("env_llm", "temperature", suffix="generation"): "0.2",}):
- class EnvLLM(openllm.LLMConfig):
- __config__ = {
- "default_id": "asdfasdf",
- "model_ids": ["asdf", "asdfasdfads"],
- "architecture": "PreTrainedModel",
- }
- field1: int = 2
+ class EnvLLM(openllm.LLMConfig):
+ __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",}
+ field1: int = 2
- class GenerationConfig:
- temperature: float = 0.8
+ class GenerationConfig:
+ temperature: float = 0.8
- sent = EnvLLM.model_construct_env()
- assert sent.field1 == 4
- assert sent["temperature"] == 0.2
-
- overwrite_default = EnvLLM()
- assert overwrite_default.field1 == 4
- assert overwrite_default["temperature"] == 0.2
+ sent = EnvLLM.model_construct_env()
+ assert sent.field1 == 4
+ assert sent["temperature"] == 0.2
+ overwrite_default = EnvLLM()
+ assert overwrite_default.field1 == 4
+ assert overwrite_default["temperature"] == 0.2
def test_struct_provided_fields():
- class EnvLLM(openllm.LLMConfig):
- __config__ = {
- "default_id": "asdfasdf",
- "model_ids": ["asdf", "asdfasdfads"],
- "architecture": "PreTrainedModel",
- }
- field1: int = 2
+ class EnvLLM(openllm.LLMConfig):
+ __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",}
+ field1: int = 2
- class GenerationConfig:
- temperature: float = 0.8
-
- sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
- assert sent.field1 == 20
- assert sent.generation_config.temperature == 0.4
+ class GenerationConfig:
+ temperature: float = 0.8
+ sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
+ assert sent.field1 == 20
+ assert sent.generation_config.temperature == 0.4
def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mk:
- mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0))
- mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2))
- sent = make_llm_config(
- "OverwriteWithEnvAvailable",
- {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel"},
- fields=(("field1", "float", 3.0),),
- ).model_construct_env(field1=20.0, temperature=0.4)
- assert sent.generation_config.temperature == 0.4
- assert sent.field1 == 20.0
-
+ with monkeypatch.context() as mk:
+ mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0))
+ mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2))
+ sent = make_llm_config("OverwriteWithEnvAvailable", {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel"}, fields=(("field1", "float", 3.0),),).model_construct_env(field1=20.0, temperature=0.4)
+ assert sent.generation_config.temperature == 0.4
+ assert sent.field1 == 20.0
@given(model_settings())
@pytest.mark.parametrize(("return_dict", "typ"), [(True, DictStrAny), (False, transformers.GenerationConfig)])
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
- cl_ = make_llm_config("ConversionLLM", gen_settings)
- assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)
-
+ cl_ = make_llm_config("ConversionLLM", gen_settings)
+ assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)
@given(model_settings())
def test_click_conversion(gen_settings: ModelSettings):
- # currently our conversion omit Union type.
- def cli_mock(**attrs: t.Any):
- return attrs
-
- cl_ = make_llm_config("ClickConversionLLM", gen_settings)
- wrapped = cl_.to_click_options(cli_mock)
- filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
- click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")]
- assert len(filtered) == len(click_options_filtered)
+ # currently our conversion omit Union type.
+ def cli_mock(**attrs: t.Any):
+ return attrs
+ cl_ = make_llm_config("ClickConversionLLM", gen_settings)
+ wrapped = cl_.to_click_options(cli_mock)
+ filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
+ click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")]
+ assert len(filtered) == len(click_options_filtered)
@pytest.mark.parametrize("model_name", openllm.CONFIG_MAPPING.keys())
def test_configuration_dict_protocol(model_name: str):
- config = openllm.AutoConfig.for_model(model_name)
- assert isinstance(config.items(), list)
- assert isinstance(config.keys(), list)
- assert isinstance(config.values(), list)
- assert isinstance(dict(config), dict)
+ config = openllm.AutoConfig.for_model(model_name)
+ assert isinstance(config.items(), list)
+ assert isinstance(config.keys(), list)
+ assert isinstance(config.values(), list)
+ assert isinstance(dict(config), dict)
diff --git a/tests/conftest.py b/tests/conftest.py
index 32f79858..d2d96c3a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,53 +22,33 @@ import pytest
import openllm
if t.TYPE_CHECKING:
- from openllm._types import LiteralRuntime
+ from openllm._types import LiteralRuntime
+_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m", "baichuan": "baichuan-inc/Baichuan-7B",}
+_PROMPT_MAPPING = {"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",}
-_FRAMEWORK_MAPPING = {
- "flan_t5": "google/flan-t5-small",
- "opt": "facebook/opt-125m",
- "baichuan": "baichuan-inc/Baichuan-7B",
-}
-_PROMPT_MAPPING = {
- "qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",
-}
+def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
+ if model not in _FRAMEWORK_MAPPING:
+ pytest.skip(f"'{model}' is not yet supported in framework testing.")
+ runtime_impl: tuple[LiteralRuntime, ...] = tuple()
+ if model in openllm.MODEL_MAPPING_NAMES:
+ runtime_impl += ("pt",)
+ if model in openllm.MODEL_FLAX_MAPPING_NAMES:
+ runtime_impl += ("flax",)
+ if model in openllm.MODEL_TF_MAPPING_NAMES:
+ runtime_impl += ("tf",)
-def parametrise_local_llm(
- model: str,
-) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
- if model not in _FRAMEWORK_MAPPING:
- pytest.skip(f"'{model}' is not yet supported in framework testing.")
-
- runtime_impl: tuple[LiteralRuntime, ...] = tuple()
- if model in openllm.MODEL_MAPPING_NAMES:
- runtime_impl += ("pt",)
- if model in openllm.MODEL_FLAX_MAPPING_NAMES:
- runtime_impl += ("flax",)
- if model in openllm.MODEL_TF_MAPPING_NAMES:
- runtime_impl += ("tf",)
-
- for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
- llm = openllm.Runner(
- model,
- model_id=_FRAMEWORK_MAPPING[model],
- ensure_available=True,
- implementation=framework,
- init_local=True,
- )
- yield prompt, llm
-
+ for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
+ llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,)
+ yield prompt, llm
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
- if os.getenv("GITHUB_ACTIONS") is None:
- if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames:
- metafunc.parametrize(
- "prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]
- )
-
+ if os.getenv("GITHUB_ACTIONS") is None:
+ if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames:
+ metafunc.parametrize("prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
- # If no tests are collected, pytest exists with code 5, which makes the CI fail.
- if exitstatus == 5:
- session.exitstatus = 0
+ # If no tests are collected, pytest exists with code 5, which makes the CI fail.
+ if exitstatus == 5:
+ session.exitstatus = 0
diff --git a/tests/models/conftest.py b/tests/models/conftest.py
index acd85a30..9c3f9d3a 100644
--- a/tests/models/conftest.py
+++ b/tests/models/conftest.py
@@ -37,277 +37,207 @@ from openllm._llm import normalise_model_name
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
- import subprocess
+ import subprocess
- from openllm_client.runtimes.base import BaseAsyncClient
- from syrupy.assertion import SnapshotAssertion
- from syrupy.types import PropertyFilter
- from syrupy.types import PropertyMatcher
- from syrupy.types import SerializableData
- from syrupy.types import SerializedData
+ from openllm_client.runtimes.base import BaseAsyncClient
+ from syrupy.assertion import SnapshotAssertion
+ from syrupy.types import PropertyFilter
+ from syrupy.types import PropertyMatcher
+ from syrupy.types import SerializableData
+ from syrupy.types import SerializedData
- from openllm._configuration import GenerationConfig
- from openllm._types import DictStrAny
- from openllm._types import ListAny
+ from openllm._configuration import GenerationConfig
+ from openllm._types import DictStrAny
+ from openllm._types import ListAny
else:
- DictStrAny = dict
- ListAny = list
-
+ DictStrAny = dict
+ ListAny = list
class ResponseComparator(JSONSnapshotExtension):
- def serialize(
- self,
- data: SerializableData,
- *,
- exclude: PropertyFilter | None = None,
- matcher: PropertyMatcher | None = None,
- ) -> SerializedData:
- if openllm.utils.LazyType(ListAny).isinstance(data):
- data = [d.unmarshaled for d in data]
- else:
- data = data.unmarshaled
- data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher)
- return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
+ def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
+ if openllm.utils.LazyType(ListAny).isinstance(data):
+ data = [d.unmarshaled for d in data]
+ else:
+ data = data.unmarshaled
+ data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher)
+ return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
- def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
- def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
- try:
- data = orjson.loads(data)
- except orjson.JSONDecodeError as err:
- raise ValueError(f"Failed to decode JSON data: {data}") from err
- if openllm.utils.LazyType(DictStrAny).isinstance(data):
- return openllm.GenerationOutput(**data)
- elif openllm.utils.LazyType(ListAny).isinstance(data):
- return [openllm.GenerationOutput(**d) for d in data]
- else:
- raise NotImplementedError(f"Data {data} has unsupported type.")
+ def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
+ def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
+ try:
+ data = orjson.loads(data)
+ except orjson.JSONDecodeError as err:
+ raise ValueError(f"Failed to decode JSON data: {data}") from err
+ if openllm.utils.LazyType(DictStrAny).isinstance(data):
+ return openllm.GenerationOutput(**data)
+ elif openllm.utils.LazyType(ListAny).isinstance(data):
+ return [openllm.GenerationOutput(**d) for d in data]
+ else:
+ raise NotImplementedError(f"Data {data} has unsupported type.")
- serialized_data = convert_data(serialized_data)
- snapshot_data = convert_data(snapshot_data)
+ serialized_data = convert_data(serialized_data)
+ snapshot_data = convert_data(snapshot_data)
- if openllm.utils.LazyType(ListAny).isinstance(serialized_data):
- serialized_data = [serialized_data]
- if openllm.utils.LazyType(ListAny).isinstance(snapshot_data):
- snapshot_data = [snapshot_data]
+ if openllm.utils.LazyType(ListAny).isinstance(serialized_data):
+ serialized_data = [serialized_data]
+ if openllm.utils.LazyType(ListAny).isinstance(snapshot_data):
+ snapshot_data = [snapshot_data]
- def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
- return s == t
+ def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
+ return s == t
- def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
- return (
- len(s.responses) == len(t.responses)
- and all([_s == _t for _s, _t in zip(s.responses, t.responses)])
- and eq_config(s.marshaled_config, t.marshaled_config)
- )
-
- return len(serialized_data) == len(snapshot_data) and all(
- [eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)]
- )
+ def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
+ return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
+ return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
- return snapshot.use_extension(ResponseComparator)
-
+ return snapshot.use_extension(ResponseComparator)
@attr.define(init=False)
class _Handle(ABC):
- port: int
- deployment_mode: t.Literal["container", "local"]
+ port: int
+ deployment_mode: t.Literal["container", "local"]
- client: BaseAsyncClient[t.Any] = attr.field(init=False)
+ client: BaseAsyncClient[t.Any] = attr.field(init=False)
- if t.TYPE_CHECKING:
+ if t.TYPE_CHECKING:
- def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
- ...
+ def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
+ ...
- def __attrs_post_init__(self):
- self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")
+ def __attrs_post_init__(self):
+ self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")
- @abstractmethod
- def status(self) -> bool:
- raise NotImplementedError
-
- async def health(self, timeout: int = 240):
- start_time = time.time()
- while time.time() - start_time < timeout:
- if not self.status():
- raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
- await self.client.health()
- try:
- await self.client.query("sanity")
- return
- except Exception:
- time.sleep(1)
- raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
+ @abstractmethod
+ def status(self) -> bool:
+ raise NotImplementedError
+ async def health(self, timeout: int = 240):
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ if not self.status():
+ raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
+ await self.client.health()
+ try:
+ await self.client.query("sanity")
+ return
+ except Exception:
+ time.sleep(1)
+ raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
@attr.define(init=False)
class LocalHandle(_Handle):
- process: subprocess.Popen[bytes]
+ process: subprocess.Popen[bytes]
- def __init__(
- self,
- process: subprocess.Popen[bytes],
- port: int,
- deployment_mode: t.Literal["container", "local"],
- ):
- self.__attrs_init__(port, deployment_mode, process)
-
- def status(self) -> bool:
- return self.process.poll() is None
+ def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal["container", "local"],):
+ self.__attrs_init__(port, deployment_mode, process)
+ def status(self) -> bool:
+ return self.process.poll() is None
class HandleProtocol(t.Protocol):
- @contextlib.contextmanager
- def __call__(
- *,
- model: str,
- model_id: str,
- image_tag: str,
- quantize: t.AnyStr | None = None,
- ) -> t.Generator[_Handle, None, None]:
- ...
-
+ @contextlib.contextmanager
+ def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
+ ...
@attr.define(init=False)
class DockerHandle(_Handle):
- container_name: str
- docker_client: docker.DockerClient
+ container_name: str
+ docker_client: docker.DockerClient
- def __init__(
- self,
- docker_client: docker.DockerClient,
- container_name: str,
- port: int,
- deployment_mode: t.Literal["container", "local"],
- ):
- self.__attrs_init__(port, deployment_mode, container_name, docker_client)
-
- def status(self) -> bool:
- container = self.docker_client.containers.get(self.container_name)
- return container.status in ["running", "created"]
+ def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal["container", "local"],):
+ self.__attrs_init__(port, deployment_mode, container_name, docker_client)
+ def status(self) -> bool:
+ container = self.docker_client.containers.get(self.container_name)
+ return container.status in ["running", "created"]
@contextlib.contextmanager
-def _local_handle(
- model: str,
- model_id: str,
- image_tag: str,
- deployment_mode: t.Literal["container", "local"],
- quantize: t.Literal["int8", "int4", "gptq"] | None = None,
- *,
- _serve_grpc: bool = False,
-):
- with openllm.utils.reserve_free_port() as port:
- pass
+def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,):
+ with openllm.utils.reserve_free_port() as port:
+ pass
- if not _serve_grpc:
- proc = openllm.start(
- model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True
- )
- else:
- proc = openllm.start_grpc(
- model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True
- )
+ if not _serve_grpc:
+ proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
+ else:
+ proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
- yield LocalHandle(proc, port, deployment_mode)
- proc.terminate()
- proc.wait(60)
+ yield LocalHandle(proc, port, deployment_mode)
+ proc.terminate()
+ proc.wait(60)
- process_output = proc.stdout.read()
- print(process_output, file=sys.stderr)
-
- proc.stdout.close()
- if proc.stderr:
- proc.stderr.close()
+ process_output = proc.stdout.read()
+ print(process_output, file=sys.stderr)
+ proc.stdout.close()
+ if proc.stderr:
+ proc.stderr.close()
@contextlib.contextmanager
-def _container_handle(
- model: str,
- model_id: str,
- image_tag: str,
- deployment_mode: t.Literal["container", "local"],
- quantize: t.Literal["int8", "int4", "gptq"] | None = None,
- *,
- _serve_grpc: bool = False,
-):
- envvar = openllm.utils.EnvVarMixin(model)
-
- with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
- pass
- container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
- client = docker.from_env()
- try:
- container = client.containers.get(container_name)
- container.stop()
- container.wait()
- container.remove()
- except docker.errors.NotFound:
- pass
-
- args = ["serve" if not _serve_grpc else "serve-grpc"]
-
- env: DictStrAny = {}
-
- if quantize is not None:
- env[envvar.quantize] = quantize
-
- gpus = openllm.utils.device_count() or -1
- devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None
-
- container = client.containers.run(
- image_tag,
- command=args,
- name=container_name,
- environment=env,
- auto_remove=False,
- detach=True,
- device_requests=devs,
- ports={"3000/tcp": port, "3001/tcp": prom_port},
- )
-
- yield DockerHandle(client, container.name, port, deployment_mode)
-
- try:
- container.stop()
- container.wait()
- except docker.errors.NotFound:
- pass
-
- container_output = container.logs().decode("utf-8")
- print(container_output, file=sys.stderr)
+def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,):
+ envvar = openllm.utils.EnvVarMixin(model)
+ with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
+ pass
+ container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
+ client = docker.from_env()
+ try:
+ container = client.containers.get(container_name)
+ container.stop()
+ container.wait()
container.remove()
+ except docker.errors.NotFound:
+ pass
+ args = ["serve" if not _serve_grpc else "serve-grpc"]
+
+ env: DictStrAny = {}
+
+ if quantize is not None:
+ env[envvar.quantize] = quantize
+
+ gpus = openllm.utils.device_count() or -1
+ devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None
+
+ container = client.containers.run(image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={"3000/tcp": port, "3001/tcp": prom_port},)
+
+ yield DockerHandle(client, container.name, port, deployment_mode)
+
+ try:
+ container.stop()
+ container.wait()
+ except docker.errors.NotFound:
+ pass
+
+ container_output = container.logs().decode("utf-8")
+ print(container_output, file=sys.stderr)
+
+ container.remove()
@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
- stack = contextlib.ExitStack()
- yield stack
- stack.close()
-
+ stack = contextlib.ExitStack()
+ yield stack
+ stack.close()
@pytest.fixture(scope="module")
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
- loop = asyncio.get_event_loop()
- yield loop
- loop.close()
-
+ loop = asyncio.get_event_loop()
+ yield loop
+ loop.close()
@pytest.fixture(params=["container", "local"], scope="session")
def deployment_mode(request: pytest.FixtureRequest) -> str:
- return request.param
-
+ return request.param
@pytest.fixture(scope="module")
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
- if deployment_mode == "container":
- return functools.partial(_container_handle, deployment_mode=deployment_mode)
- elif deployment_mode == "local":
- return functools.partial(_local_handle, deployment_mode=deployment_mode)
- else:
- raise ValueError(f"Unknown deployment mode: {deployment_mode}")
+ if deployment_mode == "container":
+ return functools.partial(_container_handle, deployment_mode=deployment_mode)
+ elif deployment_mode == "local":
+ return functools.partial(_local_handle, deployment_mode=deployment_mode)
+ else:
+ raise ValueError(f"Unknown deployment mode: {deployment_mode}")
diff --git a/tests/models/flan_t5_test.py b/tests/models/flan_t5_test.py
index 2e149ad7..37ed2815 100644
--- a/tests/models/flan_t5_test.py
+++ b/tests/models/flan_t5_test.py
@@ -20,40 +20,30 @@ import pytest
import openllm
if t.TYPE_CHECKING:
- import contextlib
-
- from .conftest import HandleProtocol
- from .conftest import ResponseComparator
- from .conftest import _Handle
+ import contextlib
+ from .conftest import HandleProtocol
+ from .conftest import ResponseComparator
+ from .conftest import _Handle
model = "flan_t5"
model_id = "google/flan-t5-small"
-
@pytest.fixture(scope="module")
-def flan_t5_handle(
- handler: HandleProtocol,
- deployment_mode: t.Literal["container", "local"],
- clean_context: contextlib.ExitStack,
-):
- with openllm.testing.prepare(
- model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context
- ) as image_tag:
- with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
- yield handle
-
+def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
+ with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
+ with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
+ yield handle
@pytest.fixture(scope="module")
async def flan_t5(flan_t5_handle: _Handle):
- await flan_t5_handle.health(240)
- return flan_t5_handle.client
-
+ await flan_t5_handle.health(240)
+ return flan_t5_handle.client
@pytest.mark.asyncio()
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
- client = await flan_t5
- response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_response="attrs")
+ client = await flan_t5
+ response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_response="attrs")
- assert response.configuration["generation_config"]["max_new_tokens"] == 10
- assert response == response_snapshot
+ assert response.configuration["generation_config"]["max_new_tokens"] == 10
+ assert response == response_snapshot
diff --git a/tests/models/opt_test.py b/tests/models/opt_test.py
index bfbd66ac..98f99fcd 100644
--- a/tests/models/opt_test.py
+++ b/tests/models/opt_test.py
@@ -19,40 +19,30 @@ import pytest
import openllm
if t.TYPE_CHECKING:
- import contextlib
-
- from .conftest import HandleProtocol
- from .conftest import ResponseComparator
- from .conftest import _Handle
+ import contextlib
+ from .conftest import HandleProtocol
+ from .conftest import ResponseComparator
+ from .conftest import _Handle
model = "opt"
model_id = "facebook/opt-125m"
-
@pytest.fixture(scope="module")
-def opt_125m_handle(
- handler: HandleProtocol,
- deployment_mode: t.Literal["container", "local"],
- clean_context: contextlib.ExitStack,
-):
- with openllm.testing.prepare(
- model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context
- ) as image_tag:
- with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
- yield handle
-
+def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
+ with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
+ with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
+ yield handle
@pytest.fixture(scope="module")
async def opt_125m(opt_125m_handle: _Handle):
- await opt_125m_handle.health(240)
- return opt_125m_handle.client
-
+ await opt_125m_handle.health(240)
+ return opt_125m_handle.client
@pytest.mark.asyncio()
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
- client = await opt_125m
- response = await client.query("What is Deep learning?", max_new_tokens=20, return_response="attrs")
+ client = await opt_125m
+ response = await client.query("What is Deep learning?", max_new_tokens=20, return_response="attrs")
- assert response.configuration["generation_config"]["max_new_tokens"] == 20
- assert response == response_snapshot
+ assert response.configuration["generation_config"]["max_new_tokens"] == 20
+ assert response == response_snapshot
diff --git a/tests/models_test.py b/tests/models_test.py
index 850f2589..1fbf2362 100644
--- a/tests/models_test.py
+++ b/tests/models_test.py
@@ -19,25 +19,22 @@ import typing as t
import pytest
if t.TYPE_CHECKING:
- import openllm
-
+ import openllm
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
- assert llm(prompt)
-
- assert llm(prompt, temperature=0.8, top_p=0.23)
+ assert llm(prompt)
+ assert llm(prompt, temperature=0.8, top_p=0.23)
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
- assert llm(prompt)
-
- assert llm(prompt, temperature=0.9, top_k=8)
+ assert llm(prompt)
+ assert llm(prompt, temperature=0.9, top_k=8)
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
- assert llm(prompt)
+ assert llm(prompt)
- assert llm(prompt, temperature=0.95)
+ assert llm(prompt, temperature=0.95)
diff --git a/tests/package_test.py b/tests/package_test.py
index c8cbe4c7..00c67a48 100644
--- a/tests/package_test.py
+++ b/tests/package_test.py
@@ -23,55 +23,44 @@ import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
if t.TYPE_CHECKING:
- from pathlib import Path
-
+ from pathlib import Path
HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5"
-actions_xfail = functools.partial(
- pytest.mark.xfail,
- condition=os.getenv("GITHUB_ACTIONS") is not None,
- reason="Marking GitHub Actions to xfail due to flakiness and building environment not isolated.",
-)
-
+actions_xfail = functools.partial(pytest.mark.xfail, condition=os.getenv("GITHUB_ACTIONS") is not None, reason="Marking GitHub Actions to xfail due to flakiness and building environment not isolated.",)
@actions_xfail
def test_general_build_with_internal_testing():
- bento_store = BentoMLContainer.bento_store.get()
+ bento_store = BentoMLContainer.bento_store.get()
- llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
- bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
+ llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
+ bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
- assert llm.llm_type == bento.info.labels["_type"]
- assert llm.config["env"]["framework_value"] == bento.info.labels["_framework"]
-
- bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
- assert len(bento_store.list(bento.tag)) == 1
+ assert llm.llm_type == bento.info.labels["_type"]
+ assert llm.config["env"]["framework_value"] == bento.info.labels["_framework"]
+ bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
+ assert len(bento_store.list(bento.tag)) == 1
@actions_xfail
def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
- local_path = tmp_path_factory.mktemp("local_t5")
- llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
+ local_path = tmp_path_factory.mktemp("local_t5")
+ llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
- if llm.bettertransformer:
- llm.__llm_model__ = llm.model.reverse_bettertransformer()
+ if llm.bettertransformer:
+ llm.__llm_model__ = llm.model.reverse_bettertransformer()
- llm.save_pretrained(local_path)
-
- assert openllm.build("flan-t5", model_id=local_path.resolve().__fspath__(), model_version="local")
+ llm.save_pretrained(local_path)
+ assert openllm.build("flan-t5", model_id=local_path.resolve().__fspath__(), model_version="local")
@pytest.fixture()
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
- file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template"
- file.write_text(
- "{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}"
- )
- return file
-
+ file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template"
+ file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}")
+ return file
@pytest.mark.usefixtures("dockerfile_template")
@actions_xfail
def test_build_with_custom_dockerfile(dockerfile_template: Path):
- assert openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template))
+ assert openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template))
diff --git a/tests/strategies_test.py b/tests/strategies_test.py
index bde7454e..d1429e57 100644
--- a/tests/strategies_test.py
+++ b/tests/strategies_test.py
@@ -19,7 +19,7 @@ import typing as t
import pytest
if t.TYPE_CHECKING:
- from _pytest.monkeypatch import MonkeyPatch
+ from _pytest.monkeypatch import MonkeyPatch
import bentoml
from bentoml._internal.resource import get_resource
@@ -28,186 +28,162 @@ from openllm._strategies import CascadingResourceStrategy
from openllm._strategies import NvidiaGpuResource
def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "0,1")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 2
- assert resource == ["0", "1"]
- mcls.delenv("CUDA_VISIBLE_DEVICES")
-
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "0,1")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 2
+ assert resource == ["0", "1"]
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "0,2,-1,1")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 2
- assert resource == ["0", "2"]
- mcls.delenv("CUDA_VISIBLE_DEVICES")
-
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "0,2,-1,1")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 2
+ assert resource == ["0", "2"]
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "-1")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 0
- assert resource == []
- mcls.delenv("CUDA_VISIBLE_DEVICES")
-
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "-1")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 0
+ assert resource == []
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43-ac33420d4628")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 1
- assert resource == ["GPU-5ebe9f43-ac33420d4628"]
- mcls.delenv("CUDA_VISIBLE_DEVICES")
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,GPU-ac33420d4628")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 2
- assert resource == ["GPU-5ebe9f43", "GPU-ac33420d4628"]
- mcls.delenv("CUDA_VISIBLE_DEVICES")
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,-1,GPU-ac33420d4628")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 1
- assert resource == ["GPU-5ebe9f43"]
- mcls.delenv("CUDA_VISIBLE_DEVICES")
- with monkeypatch.context() as mcls:
- mcls.setenv("CUDA_VISIBLE_DEVICES", "MIG-GPU-5ebe9f43-ac33420d4628")
- resource = NvidiaGpuResource.from_system()
- assert len(resource) == 1
- assert resource == ["MIG-GPU-5ebe9f43-ac33420d4628"]
- mcls.delenv("CUDA_VISIBLE_DEVICES")
-
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43-ac33420d4628")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 1
+ assert resource == ["GPU-5ebe9f43-ac33420d4628"]
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,GPU-ac33420d4628")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 2
+ assert resource == ["GPU-5ebe9f43", "GPU-ac33420d4628"]
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,-1,GPU-ac33420d4628")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 1
+ assert resource == ["GPU-5ebe9f43"]
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
+ with monkeypatch.context() as mcls:
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "MIG-GPU-5ebe9f43-ac33420d4628")
+ resource = NvidiaGpuResource.from_system()
+ assert len(resource) == 1
+ assert resource == ["MIG-GPU-5ebe9f43-ac33420d4628"]
+ mcls.delenv("CUDA_VISIBLE_DEVICES")
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="skip GPUs test on CI")
def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mcls:
- # to make this tests works with system that has GPU
- mcls.setenv("CUDA_VISIBLE_DEVICES", "")
- assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
-
- assert pytest.raises(
- ValueError,
- NvidiaGpuResource.validate,
- [*NvidiaGpuResource.from_system(), 1],
- ).match("Input list should be all string type.")
- assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match(
- "Input list should be all string type."
- )
- assert pytest.raises(ValueError, NvidiaGpuResource.validate, ["GPU-5ebe9f43", "GPU-ac33420d4628"]).match(
- "Failed to parse available GPUs UUID"
- )
+ with monkeypatch.context() as mcls:
+ # to make this tests works with system that has GPU
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "")
+ assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
+ assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match("Input list should be all string type.")
+ assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match("Input list should be all string type.")
+ assert pytest.raises(ValueError, NvidiaGpuResource.validate, ["GPU-5ebe9f43", "GPU-ac33420d4628"]).match("Failed to parse available GPUs UUID")
def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
- with monkeypatch.context() as mcls:
- # to make this tests works with system that has GPU
- mcls.setenv("CUDA_VISIBLE_DEVICES", "")
- assert NvidiaGpuResource.from_spec(1) == ["0"]
- assert NvidiaGpuResource.from_spec("5") == ["0", "1", "2", "3", "4"]
- assert NvidiaGpuResource.from_spec(1) == ["0"]
- assert NvidiaGpuResource.from_spec(2) == ["0", "1"]
- assert NvidiaGpuResource.from_spec("3") == ["0", "1", "2"]
- assert NvidiaGpuResource.from_spec([1, 3]) == ["1", "3"]
- assert NvidiaGpuResource.from_spec(["1", "3"]) == ["1", "3"]
- assert NvidiaGpuResource.from_spec(-1) == []
- assert NvidiaGpuResource.from_spec("-1") == []
- assert NvidiaGpuResource.from_spec("") == []
- assert NvidiaGpuResource.from_spec("-2") == []
- assert NvidiaGpuResource.from_spec("GPU-288347ab") == ["GPU-288347ab"]
- assert NvidiaGpuResource.from_spec("GPU-288347ab,-1,GPU-ac33420d4628") == ["GPU-288347ab"]
- assert NvidiaGpuResource.from_spec("GPU-288347ab,GPU-ac33420d4628") == ["GPU-288347ab", "GPU-ac33420d4628"]
- assert NvidiaGpuResource.from_spec("MIG-GPU-288347ab") == ["MIG-GPU-288347ab"]
-
- with pytest.raises(TypeError):
- NvidiaGpuResource.from_spec((1, 2, 3))
- with pytest.raises(TypeError):
- NvidiaGpuResource.from_spec(1.5)
- with pytest.raises(ValueError):
- assert NvidiaGpuResource.from_spec(-2)
+ with monkeypatch.context() as mcls:
+ # to make this tests works with system that has GPU
+ mcls.setenv("CUDA_VISIBLE_DEVICES", "")
+ assert NvidiaGpuResource.from_spec(1) == ["0"]
+ assert NvidiaGpuResource.from_spec("5") == ["0", "1", "2", "3", "4"]
+ assert NvidiaGpuResource.from_spec(1) == ["0"]
+ assert NvidiaGpuResource.from_spec(2) == ["0", "1"]
+ assert NvidiaGpuResource.from_spec("3") == ["0", "1", "2"]
+ assert NvidiaGpuResource.from_spec([1, 3]) == ["1", "3"]
+ assert NvidiaGpuResource.from_spec(["1", "3"]) == ["1", "3"]
+ assert NvidiaGpuResource.from_spec(-1) == []
+ assert NvidiaGpuResource.from_spec("-1") == []
+ assert NvidiaGpuResource.from_spec("") == []
+ assert NvidiaGpuResource.from_spec("-2") == []
+ assert NvidiaGpuResource.from_spec("GPU-288347ab") == ["GPU-288347ab"]
+ assert NvidiaGpuResource.from_spec("GPU-288347ab,-1,GPU-ac33420d4628") == ["GPU-288347ab"]
+ assert NvidiaGpuResource.from_spec("GPU-288347ab,GPU-ac33420d4628") == ["GPU-288347ab", "GPU-ac33420d4628"]
+ assert NvidiaGpuResource.from_spec("MIG-GPU-288347ab") == ["MIG-GPU-288347ab"]
+ with pytest.raises(TypeError):
+ NvidiaGpuResource.from_spec((1, 2, 3))
+ with pytest.raises(TypeError):
+ NvidiaGpuResource.from_spec(1.5)
+ with pytest.raises(ValueError):
+ assert NvidiaGpuResource.from_spec(-2)
class GPURunnable(bentoml.Runnable):
- SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu")
-
+ SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu")
def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False):
- return get_resource(x, y, validate=validate)
-
+ return get_resource(x, y, validate=validate)
@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
- monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 2
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 2) == 4
- assert pytest.raises(
- ValueError,
- CascadingResourceStrategy.get_worker_count,
- GPURunnable,
- {gpu_type: 0},
- 1,
- ).match("No known supported resource available for *")
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 2
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 2) == 4
-
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 0.5) == 1
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 2
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 2
- assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 2
+ monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 2
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 2) == 4
+ assert pytest.raises(ValueError, CascadingResourceStrategy.get_worker_count, GPURunnable, {gpu_type: 0}, 1,).match("No known supported resource available for *")
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 2
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 2) == 4
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 0.5) == 1
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 2
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 2
+ assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 2
@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
- monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
+ monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "1"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "1"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "1"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "2"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "1"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "2"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "8,9"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7,8,9"
-
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "2,6"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8"
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "9"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "8,9"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7,8,9"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "2,6"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8"
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "9"
@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str):
- monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
+ monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
- monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
- assert envs.get("CUDA_VISIBLE_DEVICES") == ""
- monkeypatch.delenv("CUDA_VISIBLE_DEVICES")
+ monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == ""
+ monkeypatch.delenv("CUDA_VISIBLE_DEVICES")
- monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "-1")
- envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
- assert envs.get("CUDA_VISIBLE_DEVICES") == "-1"
- monkeypatch.delenv("CUDA_VISIBLE_DEVICES")
+ monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "-1")
+ envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
+ assert envs.get("CUDA_VISIBLE_DEVICES") == "-1"
+ monkeypatch.delenv("CUDA_VISIBLE_DEVICES")
diff --git a/tools/assert-model-table-latest b/tools/assert-model-table-latest
index af19d26e..5eb99b7c 100755
--- a/tools/assert-model-table-latest
+++ b/tools/assert-model-table-latest
@@ -8,32 +8,23 @@ import sys
from markdown_it import MarkdownIt
-
md = MarkdownIt()
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
with open(os.path.join(ROOT, "README.md"), "r") as f:
- readme = md.parse(f.read())
+ readme = md.parse(f.read())
# NOTE: Currently, we only have one table in README, which is the Model readme.
table = [r for r in readme if r.type == "html_block" and r.content.startswith(" dict[int, str]:
- return {
- v: status
- for v, status in zip(
- range(1, 8),
- [
- "1 - Planning",
- "2 - Pre-Alpha",
- "3 - Alpha",
- "4 - Beta",
- "5 - Production/Stable",
- "6 - Mature",
- "7 - Inactive",
- ],
- )
- }
+ @staticmethod
+ def status() -> dict[int, str]:
+ return {v: status for v, status in zip(range(1, 8), ["1 - Planning", "2 - Pre-Alpha", "3 - Alpha", "4 - Beta", "5 - Production/Stable", "6 - Mature", "7 - Inactive",],)}
- @staticmethod
- def apache() -> str:
- return Classifier.create_classifier("license", "OSI Approved", "Apache Software License")
+ @staticmethod
+ def apache() -> str:
+ return Classifier.create_classifier("license", "OSI Approved", "Apache Software License")
- @staticmethod
- def create_classifier(identifier: str, *decls: t.Any) -> str:
- cls_ = Classifier()
- if identifier not in cls_.identifier:
- raise ValueError(f"{identifier} is not yet supported (supported alias: {Classifier.identifier})")
- return cls_.joiner.join([cls_.identifier[identifier], *decls])
+ @staticmethod
+ def create_classifier(identifier: str, *decls: t.Any) -> str:
+ cls_ = Classifier()
+ if identifier not in cls_.identifier:
+ raise ValueError(f"{identifier} is not yet supported (supported alias: {Classifier.identifier})")
+ return cls_.joiner.join([cls_.identifier[identifier], *decls])
- @staticmethod
- def create_python_classifier(
- implementation: list[str] | None = None, supported_version: list[str] | None = None
- ) -> list[str]:
- if supported_version is None:
- supported_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
- if implementation is None:
- implementation = ["CPython", "PyPy"]
- base = [
- Classifier.create_classifier("language", "Python"),
- Classifier.create_classifier("language", "Python", "3"),
- ]
- base.append(Classifier.create_classifier("language", "Python", "3", "Only"))
- base.extend([Classifier.create_classifier("language", "Python", version) for version in supported_version])
- base.extend(
- [Classifier.create_classifier("language", "Python", "Implementation", impl) for impl in implementation]
- )
- return base
-
- @staticmethod
- def create_status_classifier(level: int) -> str:
- return Classifier.create_classifier("status", Classifier.status()[level])
+ @staticmethod
+ def create_python_classifier(implementation: list[str] | None = None, supported_version: list[str] | None = None) -> list[str]:
+ if supported_version is None:
+ supported_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ if implementation is None:
+ implementation = ["CPython", "PyPy"]
+ base = [Classifier.create_classifier("language", "Python"), Classifier.create_classifier("language", "Python", "3"),]
+ base.append(Classifier.create_classifier("language", "Python", "3", "Only"))
+ base.extend([Classifier.create_classifier("language", "Python", version) for version in supported_version])
+ base.extend([Classifier.create_classifier("language", "Python", "Implementation", impl) for impl in implementation])
+ return base
+ @staticmethod
+ def create_status_classifier(level: int) -> str:
+ return Classifier.create_classifier("status", Classifier.status()[level])
@dataclasses.dataclass(frozen=True)
class Dependencies:
- name: str
- git_repo_url: t.Optional[str] = None
- branch: t.Optional[str] = None
- extensions: t.Optional[t.List[str]] = None
- subdirectory: t.Optional[str] = None
- requires_gpu: bool = False
- lower_constraint: t.Optional[str] = None
- upper_constraint: t.Optional[str] = None
- platform: t.Optional[t.Tuple[t.Literal["Linux", "Windows", "Darwin"], t.Literal["eq", "ne"]]] = None
+ name: str
+ git_repo_url: t.Optional[str] = None
+ branch: t.Optional[str] = None
+ extensions: t.Optional[t.List[str]] = None
+ subdirectory: t.Optional[str] = None
+ requires_gpu: bool = False
+ lower_constraint: t.Optional[str] = None
+ upper_constraint: t.Optional[str] = None
+ platform: t.Optional[t.Tuple[t.Literal["Linux", "Windows", "Darwin"], t.Literal["eq", "ne"]]] = None
- def with_options(self, **kwargs: t.Any) -> Dependencies:
- return dataclasses.replace(self, **kwargs)
+ def with_options(self, **kwargs: t.Any) -> Dependencies:
+ return dataclasses.replace(self, **kwargs)
- @property
- def has_constraint(self) -> bool:
- return self.lower_constraint is not None or self.upper_constraint is not None
+ @property
+ def has_constraint(self) -> bool:
+ return self.lower_constraint is not None or self.upper_constraint is not None
- @property
- def pypi_extensions(self) -> str:
- return "" if self.extensions is None else f"[{','.join(self.extensions)}]"
+ @property
+ def pypi_extensions(self) -> str:
+ return "" if self.extensions is None else f"[{','.join(self.extensions)}]"
- @staticmethod
- def platform_restriction(platform: t.LiteralString, op: t.Literal["eq", "ne"] = "eq") -> str:
- return f'platform_system{"==" if op == "eq" else "!="}"{platform}"'
+ @staticmethod
+ def platform_restriction(platform: t.LiteralString, op: t.Literal["eq", "ne"] = "eq") -> str:
+ return f'platform_system{"==" if op == "eq" else "!="}"{platform}"'
- def to_str(self) -> str:
- deps: list[str] = []
- if self.lower_constraint is not None and self.upper_constraint is not None:
- dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint},<{self.upper_constraint}"
- elif self.lower_constraint is not None:
- dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint}"
- elif self.upper_constraint is not None:
- dep = f"{self.name}{self.pypi_extensions}<{self.upper_constraint}"
- elif self.subdirectory is not None:
- dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git#subdirectory={self.subdirectory}"
- elif self.branch is not None:
- dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git@{self.branch}"
- else:
- dep = f"{self.name}{self.pypi_extensions}"
+ def to_str(self) -> str:
+ deps: list[str] = []
+ if self.lower_constraint is not None and self.upper_constraint is not None:
+ dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint},<{self.upper_constraint}"
+ elif self.lower_constraint is not None:
+ dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint}"
+ elif self.upper_constraint is not None:
+ dep = f"{self.name}{self.pypi_extensions}<{self.upper_constraint}"
+ elif self.subdirectory is not None:
+ dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git#subdirectory={self.subdirectory}"
+ elif self.branch is not None:
+ dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git@{self.branch}"
+ else:
+ dep = f"{self.name}{self.pypi_extensions}"
- deps.append(dep)
+ deps.append(dep)
- if self.platform:
- deps.append(self.platform_restriction(*self.platform))
+ if self.platform:
+ deps.append(self.platform_restriction(*self.platform))
- return ";".join(deps)
-
- @classmethod
- def from_tuple(cls, *decls: t.Any) -> Dependencies:
- return cls(*decls)
+ return ";".join(deps)
+ @classmethod
+ def from_tuple(cls, *decls: t.Any) -> Dependencies:
+ return cls(*decls)
_BENTOML_EXT = ["grpc", "io"]
_TRANSFORMERS_EXT = ["torch", "tokenizers", "accelerate"]
@@ -179,14 +142,8 @@ _BASE_DEPENDENCIES = [
]
_NIGHTLY_MAPPING: dict[str, Dependencies] = {
- "bentoml": Dependencies.from_tuple("bentoml", "bentoml/bentoml", "main", _BENTOML_EXT),
- "peft": Dependencies.from_tuple("peft", "huggingface/peft", "main", None),
- "transformers": Dependencies.from_tuple("transformers", "huggingface/transformers", "main", _TRANSFORMERS_EXT),
- "optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None),
- "accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None),
- "bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None),
- "trl": Dependencies.from_tuple("trl", "lvwerra/trl", "main", None),
- "vllm": Dependencies.from_tuple("vllm", "vllm-project/vllm", "main", None, None, True, None),
+ "bentoml": Dependencies.from_tuple("bentoml", "bentoml/bentoml", "main", _BENTOML_EXT), "peft": Dependencies.from_tuple("peft", "huggingface/peft", "main", None), "transformers": Dependencies.from_tuple("transformers", "huggingface/transformers", "main", _TRANSFORMERS_EXT), "optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None),
+ "accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None), "bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None), "trl": Dependencies.from_tuple("trl", "lvwerra/trl", "main", None), "vllm": Dependencies.from_tuple("vllm", "vllm-project/vllm", "main", None, None, True, None),
}
_ALL_RUNTIME_DEPS = ["flax", "jax", "jaxlib", "tensorflow", "keras"]
@@ -200,114 +157,91 @@ GGML_DEPS = ["ctransformers"]
GPTQ_DEPS = ["auto-gptq[triton]"]
VLLM_DEPS = ["vllm", "ray"]
-_base_requirements: dict[str, t.Any]= {
- inflection.dasherize(name): config_cls.__openllm_requirements__
- for name, config_cls in openllm.CONFIG_MAPPING.items()
- if config_cls.__openllm_requirements__
-}
+_base_requirements: dict[str, t.Any] = {inflection.dasherize(name): config_cls.__openllm_requirements__ for name, config_cls in openllm.CONFIG_MAPPING.items() if config_cls.__openllm_requirements__}
# shallow copy from locals()
_locals = locals().copy()
# NOTE: update this table when adding new external dependencies
# sync with openllm.utils.OPTIONAL_DEPENDENCIES
-_base_requirements.update(
- {v: _locals.get(f"{inflection.underscore(v).upper()}_DEPS") for v in openllm.utils.OPTIONAL_DEPENDENCIES}
-)
+_base_requirements.update({v: _locals.get(f"{inflection.underscore(v).upper()}_DEPS") for v in openllm.utils.OPTIONAL_DEPENDENCIES})
_base_requirements = {k: v for k, v in sorted(_base_requirements.items())}
fname = f"{os.path.basename(os.path.dirname(__file__))}/{os.path.basename(__file__)}"
-
def create_classifiers() -> Array:
- arr = tomlkit.array()
- arr.extend(
- [
- Classifier.create_status_classifier(5),
- Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA"),
- Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "12"),
- Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.8"),
- Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.7"),
- Classifier.apache(),
- Classifier.create_classifier("topic", "Scientific/Engineering", "Artificial Intelligence"),
- Classifier.create_classifier("topic", "Software Development", "Libraries"),
- Classifier.create_classifier("os", "OS Independent"),
- Classifier.create_classifier("audience", "Developers"),
- Classifier.create_classifier("audience", "Science/Research"),
- Classifier.create_classifier("audience", "System Administrators"),
- Classifier.create_classifier("typing", "Typed"),
- *Classifier.create_python_classifier(),
- ]
- )
- return arr.multiline(True)
-
+ arr = tomlkit.array()
+ arr.extend([
+ Classifier.create_status_classifier(5),
+ Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA"),
+ Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "12"),
+ Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.8"),
+ Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.7"),
+ Classifier.apache(),
+ Classifier.create_classifier("topic", "Scientific/Engineering", "Artificial Intelligence"),
+ Classifier.create_classifier("topic", "Software Development", "Libraries"),
+ Classifier.create_classifier("os", "OS Independent"),
+ Classifier.create_classifier("audience", "Developers"),
+ Classifier.create_classifier("audience", "Science/Research"),
+ Classifier.create_classifier("audience", "System Administrators"),
+ Classifier.create_classifier("typing", "Typed"), *Classifier.create_python_classifier(),
+ ])
+ return arr.multiline(True)
def create_optional_table() -> Table:
- all_array = tomlkit.array()
- all_array.extend([f"openllm[{k}]" for k in _base_requirements])
+ all_array = tomlkit.array()
+ all_array.extend([f"openllm[{k}]" for k in _base_requirements])
- table = tomlkit.table(is_super_table=True)
- _base_requirements.update({"all": all_array.multiline(True)})
- table.update({k: v for k, v in sorted(_base_requirements.items())})
- table.add(tomlkit.nl())
-
- return table
+ table = tomlkit.table(is_super_table=True)
+ _base_requirements.update({"all": all_array.multiline(True)})
+ table.update({k: v for k, v in sorted(_base_requirements.items())})
+ table.add(tomlkit.nl())
+ return table
def create_url_table() -> Table:
- table = tomlkit.table()
- _urls = {
- "Blog": "https://modelserving.com",
- "Chat": "https://discord.gg/openllm",
- "Documentation": "https://github.com/bentoml/openllm#readme",
- "GitHub": "https://github.com/bentoml/openllm",
- "History": "https://github.com/bentoml/openllm/blob/main/CHANGELOG.md",
- "Homepage": "https://bentoml.com",
- "Tracker": "https://github.com/bentoml/openllm/issues",
- "Twitter": "https://twitter.com/bentomlai",
- }
- table.update({k: v for k, v in sorted(_urls.items())})
- return table
+ table = tomlkit.table()
+ _urls = {
+ "Blog": "https://modelserving.com", "Chat": "https://discord.gg/openllm", "Documentation": "https://github.com/bentoml/openllm#readme", "GitHub": "https://github.com/bentoml/openllm", "History": "https://github.com/bentoml/openllm/blob/main/CHANGELOG.md", "Homepage": "https://bentoml.com", "Tracker": "https://github.com/bentoml/openllm/issues",
+ "Twitter": "https://twitter.com/bentomlai",
+ }
+ table.update({k: v for k, v in sorted(_urls.items())})
+ return table
def build_cli_extensions() -> Table:
- table = tomlkit.table()
- ext: dict[str, str] = {"openllm": "openllm.cli.entrypoint:cli"}
- ext.update({f"openllm-{inflection.dasherize(ke)}": f"openllm.cli.ext.{ke}:cli" for ke in sorted([fname[:-3]
- for fname in os.listdir(os.path.abspath(os.path.join(ROOT, "src", "openllm", "cli", "ext")))
- if fname.endswith(".py") and not fname.startswith("__")])})
- table.update(ext)
- return table
+ table = tomlkit.table()
+ ext: dict[str, str] = {"openllm": "openllm.cli.entrypoint:cli"}
+ ext.update({f"openllm-{inflection.dasherize(ke)}": f"openllm.cli.ext.{ke}:cli" for ke in sorted([fname[:-3] for fname in os.listdir(os.path.abspath(os.path.join(ROOT, "src", "openllm", "cli", "ext"))) if fname.endswith(".py") and not fname.startswith("__")])})
+ table.update(ext)
+ return table
def main() -> int:
- with open(os.path.join(ROOT, "pyproject.toml"), "r") as f:
- pyproject = tomlkit.parse(f.read())
+ with open(os.path.join(ROOT, "pyproject.toml"), "r") as f:
+ pyproject = tomlkit.parse(f.read())
- dependencies_array = tomlkit.array()
- dependencies_array.extend([v.to_str() for v in _BASE_DEPENDENCIES])
+ dependencies_array = tomlkit.array()
+ dependencies_array.extend([v.to_str() for v in _BASE_DEPENDENCIES])
- pyproject["project"]["urls"] = create_url_table()
- pyproject["project"]["scripts"] = build_cli_extensions()
- pyproject["project"]["classifiers"] = create_classifiers()
- pyproject["project"]["optional-dependencies"] = create_optional_table()
- pyproject["project"]["dependencies"] = dependencies_array.multiline(True)
+ pyproject["project"]["urls"] = create_url_table()
+ pyproject["project"]["scripts"] = build_cli_extensions()
+ pyproject["project"]["classifiers"] = create_classifiers()
+ pyproject["project"]["optional-dependencies"] = create_optional_table()
+ pyproject["project"]["dependencies"] = dependencies_array.multiline(True)
- with open(os.path.join(ROOT, "pyproject.toml"), "w") as f:
- f.write(tomlkit.dumps(pyproject))
+ with open(os.path.join(ROOT, "pyproject.toml"), "w") as f:
+ f.write(tomlkit.dumps(pyproject))
- with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f:
- f.write(f"# This file is generated by `{fname}`. DO NOT EDIT\n-e .[playground,flan-t5]\n")
- f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if not v.requires_gpu])
- with open(os.path.join(ROOT, "nightly-requirements-gpu.txt"), "w") as f:
- f.write(f"# This file is generated by `{fname}`. # DO NOT EDIT\n")
- f.write(
- "# For Jax, Flax, Tensorflow, PyTorch CUDA support, please refers to their official installation for your specific setup.\n"
- )
- f.write("-r nightly-requirements.txt\n-e .[all]\n")
- f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if v.requires_gpu])
-
- return 0
+ with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f:
+ f.write(f"# This file is generated by `{fname}`. DO NOT EDIT\n-e .[playground,flan-t5]\n")
+ f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if not v.requires_gpu])
+ with open(os.path.join(ROOT, "nightly-requirements-gpu.txt"), "w") as f:
+ f.write(f"# This file is generated by `{fname}`. # DO NOT EDIT\n")
+ f.write("# For Jax, Flax, Tensorflow, PyTorch CUDA support, please refers to their official installation for your specific setup.\n")
+ f.write("-r nightly-requirements.txt\n-e .[all]\n")
+ f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if v.requires_gpu])
+ return 0
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
diff --git a/tools/generate-coverage.py b/tools/generate-coverage.py
index 12b21cdd..7246a4ad 100755
--- a/tools/generate-coverage.py
+++ b/tools/generate-coverage.py
@@ -21,51 +21,48 @@ from pathlib import Path
import orjson
from lxml import etree
-
ROOT = Path(__file__).resolve().parent.parent
PACKAGES = {"src/openllm/": "openllm"}
-
def main() -> int:
- coverage_report = ROOT / "coverage.xml"
- root = etree.fromstring(coverage_report.read_text())
+ coverage_report = ROOT / "coverage.xml"
+ root = etree.fromstring(coverage_report.read_text())
- raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {"hits": 0, "misses": 0})
- for package in root.find("packages"):
- for module in package.find("classes"):
- filename = module.attrib["filename"]
- for relative_path, package_name in PACKAGES.items():
- if filename.startswith(relative_path):
- data = raw_package_data[package_name]
- break
- else:
- message = f"unknown package: {module}"
- raise ValueError(message)
+ raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {"hits": 0, "misses": 0})
+ for package in root.find("packages"):
+ for module in package.find("classes"):
+ filename = module.attrib["filename"]
+ for relative_path, package_name in PACKAGES.items():
+ if filename.startswith(relative_path):
+ data = raw_package_data[package_name]
+ break
+ else:
+ message = f"unknown package: {module}"
+ raise ValueError(message)
- for line in module.find("lines"):
- if line.attrib["hits"] == "1":
- data["hits"] += 1
- else:
- data["misses"] += 1
+ for line in module.find("lines"):
+ if line.attrib["hits"] == "1":
+ data["hits"] += 1
+ else:
+ data["misses"] += 1
- total_statements_covered = 0
- total_statements = 0
- coverage_data = {}
- for package_name, data in sorted(raw_package_data.items()):
- statements_covered = data["hits"]
- statements = statements_covered + data["misses"]
- total_statements_covered += statements_covered
- total_statements += statements
+ total_statements_covered = 0
+ total_statements = 0
+ coverage_data = {}
+ for package_name, data in sorted(raw_package_data.items()):
+ statements_covered = data["hits"]
+ statements = statements_covered + data["misses"]
+ total_statements_covered += statements_covered
+ total_statements += statements
- coverage_data[package_name] = {"statements_covered": statements_covered, "statements": statements}
- coverage_data["total"] = {"statements_covered": total_statements_covered, "statements": total_statements}
+ coverage_data[package_name] = {"statements_covered": statements_covered, "statements": statements}
+ coverage_data["total"] = {"statements_covered": total_statements_covered, "statements": total_statements}
- coverage_summary = ROOT / "coverage-summary.json"
- coverage_summary.write_text(orjson.dumps(coverage_data, option=orjson.OPT_INDENT_2).decode(), encoding="utf-8")
-
- return 0
+ coverage_summary = ROOT / "coverage-summary.json"
+ coverage_summary.write_text(orjson.dumps(coverage_data, option=orjson.OPT_INDENT_2).decode(), encoding="utf-8")
+ return 0
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
diff --git a/tools/update-config-stubs.py b/tools/update-config-stubs.py
index fde63612..a705578d 100755
--- a/tools/update-config-stubs.py
+++ b/tools/update-config-stubs.py
@@ -25,8 +25,7 @@ from openllm._configuration import GenerationConfig
from openllm._configuration import ModelSettings
from openllm._configuration import PeftType
-
-# currently we are assuming the indentatio level is 4 for comments
+# currently we are assuming the indentation level is 2 for comments
START_COMMENT = f"# {os.path.basename(__file__)}: start\n"
END_COMMENT = f"# {os.path.basename(__file__)}: stop\n"
START_SPECIAL_COMMENT = f"# {os.path.basename(__file__)}: special start\n"
@@ -38,28 +37,26 @@ _TARGET_FILE = Path(__file__).parent.parent / "src" / "openllm" / "_configuratio
_imported = importlib.import_module(ModelSettings.__module__)
def process_annotations(annotations: str) -> str:
- if "NotRequired" in annotations:
- return annotations[len("NotRequired[") : -1]
- elif "Required" in annotations:
- return annotations[len("Required[") : -1]
- else:
- return annotations
+ if "NotRequired" in annotations:
+ return annotations[len("NotRequired["):-1]
+ elif "Required" in annotations:
+ return annotations[len("Required["):-1]
+ else:
+ return annotations
_value_docstring = {
"default_id": """Return the default model to use when using 'openllm start '.
This could be one of the keys in 'self.model_ids' or custom users model.
This field is required when defining under '__config__'.
- """,
- "model_ids": """A list of supported pretrained models tag for this given runnable.
+ """, "model_ids": """A list of supported pretrained models tag for this given runnable.
For example:
For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base",
"google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"]
This field is required when defining under '__config__'.
- """,
- "architecture": """The model architecture that is supported by this LLM.
+ """, "architecture": """The model architecture that is supported by this LLM.
Note that any model weights within this architecture generation can always be run and supported by this LLM.
@@ -68,29 +65,16 @@ _value_docstring = {
```bash
openllm start gpt-neox --model-id stabilityai/stablelm-tuned-alpha-3b
- ```""",
- "default_implementation": """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`.
+ ```""", "default_implementation": """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`.
- It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm')
- """,
- "url": """The resolved url for this LLMConfig.""",
- "requires_gpu": """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""",
- "trust_remote_code": """Whether to always trust remote code""",
- "service_name": """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""",
+ It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm')
+ """, "url": """The resolved url for this LLMConfig.""", "requires_gpu": """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""", "trust_remote_code": """Whether to always trust remote code""", "service_name": """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""",
"requirements": """The default PyPI requirements needed to run this given LLM. By default, we will depend on
- bentoml, torch, transformers.""",
- "bettertransformer": """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""",
- "model_type": """The model type for this given LLM. By default, it should be causal language modeling.
+ bentoml, torch, transformers.""", "bettertransformer": """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""", "model_type": """The model type for this given LLM. By default, it should be causal language modeling.
Currently supported 'causal_lm' or 'seq2seq_lm'
- """,
- "runtime": """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""",
- "name_type": """The default name typed for this model. "dasherize" will convert the name to lowercase and
+ """, "runtime": """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""", "name_type": """The default name typed for this model. "dasherize" will convert the name to lowercase and
replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both
- `model_name` and `start_name` must be specified.""",
- "model_name": """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""",
- "start_name": """Default name to be used with `openllm start`""",
- "env": """A EnvVarMixin instance for this LLMConfig.""",
- "timeout": """The default timeout to be set for this given LLM.""",
+ `model_name` and `start_name` must be specified.""", "model_name": """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""", "start_name": """Default name to be used with `openllm start`""", "env": """A EnvVarMixin instance for this LLMConfig.""", "timeout": """The default timeout to be set for this given LLM.""",
"workers_per_resource": """The number of workers per resource. This is used to determine the number of workers to use for this model.
For example, if this is set to 0.5, then OpenLLM will use 1 worker per 2 resources. If this is set to 1, then
OpenLLM will use 1 worker per resource. If this is set to 2, then OpenLLM will use 2 workers per resource.
@@ -99,106 +83,58 @@ _value_docstring = {
https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more details.
By default, it is set to 1.
- """,
- "fine_tune_strategies": """The fine-tune strategies for this given LLM.""",
- "tokenizer_class": """Optional tokenizer class for this given LLM. See Llama for example.""",
+ """, "fine_tune_strategies": """The fine-tune strategies for this given LLM.""", "tokenizer_class": """Optional tokenizer class for this given LLM. See Llama for example.""",
}
_transformed = {"fine_tune_strategies": "t.Dict[AdapterType, FineTuneConfig]"}
-
def main() -> int:
- with _TARGET_FILE.open("r") as f:
- processed = f.readlines()
+ with _TARGET_FILE.open("r") as f:
+ processed = f.readlines()
- start_idx, end_idx = processed.index(" " * 4 + START_COMMENT), processed.index(" " * 4 + END_COMMENT)
- start_stub_idx, end_stub_idx = processed.index(" " * 8 + START_SPECIAL_COMMENT), processed.index(" " * 8 + END_SPECIAL_COMMENT)
- start_attrs_idx, end_attrs_idx = processed.index(" " * 8 + START_ATTRS_COMMENT), processed.index(" " * 8 + END_ATTRS_COMMENT)
+ start_idx, end_idx = processed.index(" "*2 + START_COMMENT), processed.index(" "*2 + END_COMMENT)
+ start_stub_idx, end_stub_idx = processed.index(" "*4 + START_SPECIAL_COMMENT), processed.index(" "*4 + END_SPECIAL_COMMENT)
+ start_attrs_idx, end_attrs_idx = processed.index(" "*4 + START_ATTRS_COMMENT), processed.index(" "*4 + END_ATTRS_COMMENT)
- # NOTE: inline stubs __config__ attrs representation
- special_attrs_lines: list[str] = []
- for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): special_attrs_lines.append(f"{' ' * 8}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n")
+ # NOTE: inline stubs __config__ attrs representation
+ special_attrs_lines: list[str] = []
+ for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
+ special_attrs_lines.append(f"{' ' * 4}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n")
- # NOTE: inline stubs for _ConfigAttr type stubs
- config_attr_lines: list[str] = []
- for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
- config_attr_lines.extend(
- [
- " " * 8 + line
- for line in [
- f"__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n",
- f'"""{_value_docstring[keys]}"""\n',
- ]
- ]
- )
+ # NOTE: inline stubs for _ConfigAttr type stubs
+ config_attr_lines: list[str] = []
+ for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
+ config_attr_lines.extend([" "*4 + line for line in [f"__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n", f'"""{_value_docstring[keys]}"""\n',]])
- # NOTE: inline runtime __getitem__ overload process
- lines: list[str] = []
- lines.append(" " * 4 + "# NOTE: ModelSettings arguments\n")
- for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
- lines.extend(
- [
- " " * 4 + line
- for line in [
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- f'def __getitem__(self, item: t.Literal["{keys}"]) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',
- ]
- ]
- )
- # special case variables: generation_class, extras, sampling_class
- lines.append(" " * 4 + "# NOTE: generation_class, sampling_class and extras arguments\n")
- lines.extend(
- [
- " " * 4 + line
- for line in [
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- 'def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm._configuration.GenerationConfig]: ...\n',
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- 'def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm._configuration.SamplingParams]: ...\n',
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- 'def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ...\n',
- ]
- ]
- )
- lines.append(" " * 4 + "# NOTE: GenerationConfig arguments\n")
- generation_config_anns = openllm.utils.codegen.get_annotations(GenerationConfig)
- for keys, type_pep563 in generation_config_anns.items():
- lines.extend(
- [
- " " * 4 + line
- for line in [
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',
- ]
- ]
- )
- lines.append(" " * 4 + "# NOTE: SamplingParams arguments\n")
- for keys, type_pep563 in openllm.utils.codegen.get_annotations(SamplingParams).items():
- if keys not in generation_config_anns:
- lines.extend(
- [
- " " * 4 + line
- for line in [
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',
- ]
- ]
- )
+ # NOTE: inline runtime __getitem__ overload process
+ lines: list[str] = []
+ lines.append(" "*2 + "# NOTE: ModelSettings arguments\n")
+ for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
+ lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys}"]) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',]])
+ # special case variables: generation_class, extras, sampling_class
+ lines.append(" "*2 + "# NOTE: generation_class, sampling_class and extras arguments\n")
+ lines.extend([
+ " "*2 + line for line in [
+ "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", 'def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm._configuration.GenerationConfig]: ...\n', "@overload\n"
+ if "overload" in dir(_imported) else "@t.overload\n", 'def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm._configuration.SamplingParams]: ...\n', "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", 'def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ...\n',
+ ]
+ ])
+ lines.append(" "*2 + "# NOTE: GenerationConfig arguments\n")
+ generation_config_anns = openllm.utils.codegen.get_annotations(GenerationConfig)
+ for keys, type_pep563 in generation_config_anns.items():
+ lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',]])
+ lines.append(" "*2 + "# NOTE: SamplingParams arguments\n")
+ for keys, type_pep563 in openllm.utils.codegen.get_annotations(SamplingParams).items():
+ if keys not in generation_config_anns:
+ lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',]])
- lines.append(" " * 4 + "# NOTE: PeftType arguments\n")
- for keys in PeftType._member_names_:
- lines.extend(
- [
- " " * 4 + line
- for line in [
- "@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
- f'def __getitem__(self, item: t.Literal["{keys.lower()}"]) -> dict[str, t.Any]: ...\n',
- ]
- ]
- )
+ lines.append(" "*2 + "# NOTE: PeftType arguments\n")
+ for keys in PeftType._member_names_:
+ lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys.lower()}"]) -> dict[str, t.Any]: ...\n',]])
- processed = processed[:start_attrs_idx] + [" " * 8 + START_ATTRS_COMMENT, *special_attrs_lines, " " * 8 + END_ATTRS_COMMENT] + processed[end_attrs_idx + 1 : start_stub_idx] + [" " * 8 + START_SPECIAL_COMMENT, *config_attr_lines, " " * 8 + END_SPECIAL_COMMENT] + processed[end_stub_idx + 1 : start_idx] + [" " * 4 + START_COMMENT, *lines, " " * 4 + END_COMMENT] + processed[end_idx + 1 :]
- with _TARGET_FILE.open("w") as f: f.writelines(processed)
- return 0
+ processed = processed[:start_attrs_idx] + [" "*4 + START_ATTRS_COMMENT, *special_attrs_lines, " "*4 + END_ATTRS_COMMENT] + processed[end_attrs_idx + 1:start_stub_idx] + [" "*4 + START_SPECIAL_COMMENT, *config_attr_lines, " "*4 + END_SPECIAL_COMMENT] + processed[end_stub_idx + 1:start_idx] + [" "*2 + START_COMMENT, *lines, " "*2 + END_COMMENT] + processed[end_idx + 1:]
+ with _TARGET_FILE.open("w") as f:
+ f.writelines(processed)
+ return 0
if __name__ == "__main__": raise SystemExit(main())
diff --git a/tools/update-models-import.py b/tools/update-models-import.py
index 1e95ee05..2e9a8f60 100755
--- a/tools/update-models-import.py
+++ b/tools/update-models-import.py
@@ -20,25 +20,30 @@ import openllm
_TARGET_FILE = Path(__file__).parent.parent / "src" / "openllm" / "models" / "__init__.py"
-def comment_generator(comment_type: str, action: t.Literal["start", "stop"] = "start", indentation: int = 0) -> str: return " " * indentation + f"# {os.path.basename(__file__)}: {action} {comment_type}\n"
+def comment_generator(comment_type: str, action: t.Literal["start", "stop"] = "start", indentation: int = 0) -> str:
+ return " "*indentation + f"# {os.path.basename(__file__)}: {action} {comment_type}\n"
START_MODULE_COMMENT, STOP_MODULE_COMMENT = comment_generator("module"), comment_generator("module", "stop")
-START_TYPES_COMMENT, STOP_TYPES_COMMENT = comment_generator("types", indentation=4), comment_generator("types", "stop", indentation=4)
+START_TYPES_COMMENT, STOP_TYPES_COMMENT = comment_generator("types", indentation=2), comment_generator("types", "stop", indentation=2)
-@openllm.utils.apply(lambda v: sorted([" " * 4 + _ for _ in v], key=lambda k: k.split()[-1]))
-def create_stubs_import() -> list[str]: return [f"from . import {p.name} as {p.name}\n" for p in _TARGET_FILE.parent.glob("*/") if p.name not in {"__pycache__", "__init__.py", ".DS_Store"}]
-def create_module_import() -> str: return f"_MODELS: set[str] = {{{', '.join(sorted([repr(p.name) for p in _TARGET_FILE.parent.glob('*/') if p.name not in ['__pycache__', '__init__.py', '.DS_Store']]))}}}\n"
+@openllm.utils.apply(lambda v: sorted([" "*2 + _ for _ in v], key=lambda k: k.split()[-1]))
+def create_stubs_import() -> list[str]:
+ return [f"from . import {p.name} as {p.name}\n" for p in _TARGET_FILE.parent.glob("*/") if p.name not in {"__pycache__", "__init__.py", ".DS_Store"}]
+
+def create_module_import() -> str:
+ return f"_MODELS: set[str] = {{{', '.join(sorted([repr(p.name) for p in _TARGET_FILE.parent.glob('*/') if p.name not in ['__pycache__', '__init__.py', '.DS_Store']]))}}}\n"
def main() -> int:
- with _TARGET_FILE.open("r") as f: processed = f.readlines()
- stubs_lines, module_line = create_stubs_import(), create_module_import()
-
- start_module_idx, stop_module_idx = processed.index(START_MODULE_COMMENT), processed.index(STOP_MODULE_COMMENT)
- start_types_idx, stop_types_idex = processed.index(START_TYPES_COMMENT), processed.index(STOP_TYPES_COMMENT)
- processed = processed[:start_module_idx] + [START_MODULE_COMMENT, module_line, STOP_MODULE_COMMENT] + processed[stop_module_idx+1:start_types_idx] + [START_TYPES_COMMENT, *stubs_lines, STOP_TYPES_COMMENT] + processed[stop_types_idex+1:]
- with _TARGET_FILE.open("w") as f: f.writelines(processed)
- return 0
+ with _TARGET_FILE.open("r") as f:
+ processed = f.readlines()
+ stubs_lines, module_line = create_stubs_import(), create_module_import()
+ start_module_idx, stop_module_idx = processed.index(START_MODULE_COMMENT), processed.index(STOP_MODULE_COMMENT)
+ start_types_idx, stop_types_idex = processed.index(START_TYPES_COMMENT), processed.index(STOP_TYPES_COMMENT)
+ processed = processed[:start_module_idx] + [START_MODULE_COMMENT, module_line, STOP_MODULE_COMMENT] + processed[stop_module_idx + 1:start_types_idx] + [START_TYPES_COMMENT, *stubs_lines, STOP_TYPES_COMMENT] + processed[stop_types_idex + 1:]
+ with _TARGET_FILE.open("w") as f:
+ f.writelines(processed)
+ return 0
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
diff --git a/tools/update-readme.py b/tools/update-readme.py
index 6c819f82..1a092013 100755
--- a/tools/update-readme.py
+++ b/tools/update-readme.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from __future__ import annotations
import os
@@ -24,76 +23,62 @@ import tomlkit
import openllm
-
START_COMMENT = f"\n"
END_COMMENT = f"\n"
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def main() -> int:
- with open(os.path.join(ROOT, "pyproject.toml"), "r") as f: deps = tomlkit.parse(f.read()).value["project"]["optional-dependencies"]
- with open(os.path.join(ROOT, "README.md"), "r") as f: readme = f.readlines()
+ with open(os.path.join(ROOT, "pyproject.toml"), "r") as f:
+ deps = tomlkit.parse(f.read()).value["project"]["optional-dependencies"]
+ with open(os.path.join(ROOT, "README.md"), "r") as f:
+ readme = f.readlines()
- start_index, stop_index = readme.index(START_COMMENT), readme.index(END_COMMENT)
- formatted: dict[t.Literal["Model", "Architecture", "URL", "Installation", "Model Ids"], list[str | list[str]]] = {
- "Model": [],
- "Architecture": [],
- "URL": [],
- "Model Ids": [],
- "Installation": [],
- }
- max_install_len_div = 0
- for name, config_cls in openllm.CONFIG_MAPPING.items():
- dashed = inflection.dasherize(name)
- formatted["Model"].append(dashed)
- formatted["Architecture"].append(config_cls.__openllm_architecture__)
- formatted["URL"].append(config_cls.__openllm_url__)
- formatted["Model Ids"].append(config_cls.__openllm_model_ids__)
- if dashed in deps:
- instruction = f'```bash\npip install "openllm[{dashed}]"\n```'
- else:
- instruction = "```bash\npip install openllm\n```"
- if len(instruction) > max_install_len_div:
- max_install_len_div = len(instruction)
- formatted["Installation"].append(instruction)
+ start_index, stop_index = readme.index(START_COMMENT), readme.index(END_COMMENT)
+ formatted: dict[t.Literal["Model", "Architecture", "URL", "Installation", "Model Ids"], list[str | list[str]]] = {"Model": [], "Architecture": [], "URL": [], "Model Ids": [], "Installation": [],}
+ max_install_len_div = 0
+ for name, config_cls in openllm.CONFIG_MAPPING.items():
+ dashed = inflection.dasherize(name)
+ formatted["Model"].append(dashed)
+ formatted["Architecture"].append(config_cls.__openllm_architecture__)
+ formatted["URL"].append(config_cls.__openllm_url__)
+ formatted["Model Ids"].append(config_cls.__openllm_model_ids__)
+ if dashed in deps:
+ instruction = f'```bash\npip install "openllm[{dashed}]"\n```'
+ else:
+ instruction = "```bash\npip install openllm\n```"
+ if len(instruction) > max_install_len_div:
+ max_install_len_div = len(instruction)
+ formatted["Installation"].append(instruction)
- meta: list[str] = ["\n", "\n"]
+ meta: list[str] = ["\n", "\n"]
- # NOTE: headers
- meta += ["\n"]
- meta.extend([f"| {header} | \n" for header in formatted.keys() if header not in ("URL",)])
- meta += [" \n"]
- # NOTE: rows
- for name, architecture, url, model_ids, installation in t.cast(t.Iterable[t.Tuple[str, str, str, t.List[str], str]], zip(*formatted.values())):
- meta += "\n"
- # configure architecture URL
- cfg_cls = openllm.CONFIG_MAPPING[name]
- if cfg_cls.__openllm_trust_remote_code__:
- arch = f"{architecture} | \n"
- else:
- model_name = {
- "dolly_v2": "gpt_neox",
- "stablelm": "gpt_neox",
- "starcoder": "gpt_bigcode",
- "flan_t5": "t5",
- }.get(cfg_cls.__openllm_model_name__, cfg_cls.__openllm_model_name__)
- arch = f"{architecture} | \n"
- meta.extend(
- [
- f"\n{name} | \n",
- arch,
- ]
- )
- format_with_links: list[str] = []
- for lid in model_ids:
- format_with_links.append(f"{lid}")
- meta.append("\n\n" + "\n".join(format_with_links) + " \n\n | \n")
- meta.append(f"\n\n{installation}\n\n | \n")
- meta += " \n"
- meta.extend([" \n", "\n"])
+ # NOTE: headers
+ meta += ["\n"]
+ meta.extend([f"| {header} | \n" for header in formatted.keys() if header not in ("URL",)])
+ meta += [" \n"]
+ # NOTE: rows
+ for name, architecture, url, model_ids, installation in t.cast(t.Iterable[t.Tuple[str, str, str, t.List[str], str]], zip(*formatted.values())):
+ meta += "\n"
+ # configure architecture URL
+ cfg_cls = openllm.CONFIG_MAPPING[name]
+ if cfg_cls.__openllm_trust_remote_code__:
+ arch = f"{architecture} | \n"
+ else:
+ model_name = {"dolly_v2": "gpt_neox", "stablelm": "gpt_neox", "starcoder": "gpt_bigcode", "flan_t5": "t5",}.get(cfg_cls.__openllm_model_name__, cfg_cls.__openllm_model_name__)
+ arch = f"{architecture} | \n"
+ meta.extend([f"\n{name} | \n", arch,])
+ format_with_links: list[str] = []
+ for lid in model_ids:
+ format_with_links.append(f"{lid}")
+ meta.append("\n\n" + "\n".join(format_with_links) + " \n\n | \n")
+ meta.append(f"\n\n{installation}\n\n | \n")
+ meta += " \n"
+ meta.extend([" \n", "\n"])
- readme = readme[:start_index] + [START_COMMENT] + meta + [END_COMMENT] + readme[stop_index + 1 :]
- with open(os.path.join(ROOT, "README.md"), "w") as f: f.writelines(readme)
- return 0
+ readme = readme[:start_index] + [START_COMMENT] + meta + [END_COMMENT] + readme[stop_index + 1:]
+ with open(os.path.join(ROOT, "README.md"), "w") as f:
+ f.writelines(readme)
+ return 0
if __name__ == "__main__": raise SystemExit(main())
diff --git a/tools/write-coverage-report.py b/tools/write-coverage-report.py
index 0fd55550..7a9d64c2 100755
--- a/tools/write-coverage-report.py
+++ b/tools/write-coverage-report.py
@@ -21,51 +21,41 @@ from pathlib import Path
import orjson
-
PRECISION = Decimal(".01")
ROOT = Path(__file__).resolve().parent.parent
-
def main():
- coverage_summary = ROOT / "coverage-summary.json"
+ coverage_summary = ROOT / "coverage-summary.json"
- coverage_data = orjson.loads(coverage_summary.read_text(encoding="utf-8"))
- total_data = coverage_data.pop("total")
+ coverage_data = orjson.loads(coverage_summary.read_text(encoding="utf-8"))
+ total_data = coverage_data.pop("total")
- lines = [
- "\n",
- "Package | Statements\n",
- "------- | ----------\n",
- ]
+ lines = ["\n", "Package | Statements\n", "------- | ----------\n",]
- for package, data in sorted(coverage_data.items()):
- statements_covered = data["statements_covered"]
- statements = data["statements"]
+ for package, data in sorted(coverage_data.items()):
+ statements_covered = data["statements_covered"]
+ statements = data["statements"]
- rate = Decimal(statements_covered) / Decimal(statements) * 100
- rate = rate.quantize(PRECISION, rounding=ROUND_DOWN)
- lines.append(
- f"{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n" # noqa: PLR2004
- )
+ rate = Decimal(statements_covered) / Decimal(statements) * 100
+ rate = rate.quantize(PRECISION, rounding=ROUND_DOWN)
+ lines.append(f"{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n" # noqa: PLR2004
+ )
- total_statements_covered = total_data["statements_covered"]
- total_statements = total_data["statements"]
- total_rate = Decimal(total_statements_covered) / Decimal(total_statements) * 100
- total_rate = total_rate.quantize(PRECISION, rounding=ROUND_DOWN)
- color = "ok" if float(total_rate) >= 95 else "critical"
- lines.insert(0, f"\n")
+ total_statements_covered = total_data["statements_covered"]
+ total_statements = total_data["statements"]
+ total_rate = Decimal(total_statements_covered) / Decimal(total_statements) * 100
+ total_rate = total_rate.quantize(PRECISION, rounding=ROUND_DOWN)
+ color = "ok" if float(total_rate) >= 95 else "critical"
+ lines.insert(0, f"\n")
- lines.append(
- f"**Summary** | {100 if total_rate == 100 else total_rate}% "
- f"({total_statements_covered} / {total_statements})\n"
- )
-
- coverage_report = ROOT / "coverage-report.md"
- with coverage_report.open("w", encoding="utf-8") as f:
- f.write("".join(lines))
- return 0
+ lines.append(f"**Summary** | {100 if total_rate == 100 else total_rate}% "
+ f"({total_statements_covered} / {total_statements})\n")
+ coverage_report = ROOT / "coverage-report.md"
+ with coverage_report.open("w", encoding="utf-8") as f:
+ f.write("".join(lines))
+ return 0
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
|