fix(style): reduce boilerplate and format with custom logic

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron
2023-11-26 01:44:59 -05:00
parent b4ea4b3e99
commit 69aae34cf4
6 changed files with 118 additions and 310 deletions

View File

@@ -69,49 +69,29 @@ _object_setattr = object.__setattr__
@attr.frozen(slots=True, repr=False, init=False)
class GenerationConfig(ReprMixin):
'''GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and an environment constructor.
Note that we always set `do_sample=True`. This class is not designed to be used directly, rather
to be used in conjunction with LLMConfig. The instance of the generation config can then be accessed
via ``LLMConfig.generation_config``.
'''
max_new_tokens: int = dantic.Field(
20, ge=0, description='The maximum number of tokens to generate, ignoring the number of tokens in the prompt.'
)
max_new_tokens: int = dantic.Field(20, ge=0, description='The maximum number of tokens to generate, ignoring the number of tokens in the prompt.')
min_length: int = dantic.Field(
0,
ge=0,
0, ge=0, #
description='The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.',
)
min_new_tokens: int = dantic.Field(
description='The minimum number of tokens to generate, ignoring the number of tokens in the prompt.'
)
min_new_tokens: int = dantic.Field(description='The minimum number of tokens to generate, ignoring the number of tokens in the prompt.')
early_stopping: bool = dantic.Field(
False,
description="""Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) """,
)
max_time: float = dantic.Field(
description='The maximum amount of time you allow the computation to run for, in seconds. Generation will still finish the current pass after the allotted time has passed.'
description="Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `'never'`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) ",
)
max_time: float = dantic.Field(description='The maximum amount of time you allow the computation to run for, in seconds. Generation will still finish the current pass after the allotted time has passed.')
num_beams: int = dantic.Field(1, description='Number of beams for beam search. 1 means no beam search.')
num_beam_groups: int = dantic.Field(
1,
description='Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.',
)
penalty_alpha: float = dantic.Field(
description='The value balances the model confidence and the degeneration penalty in contrastive search decoding.'
)
penalty_alpha: float = dantic.Field(description='The value balances the model confidence and the degeneration penalty in contrastive search decoding.')
use_cache: bool = dantic.Field(
True,
description='Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.',
)
temperature: float = dantic.Field(
1.0, ge=0.0, le=1.0, description='The value used to modulate the next token probabilities.'
)
top_k: int = dantic.Field(
50, description='The number of highest probability vocabulary tokens to keep for top-k-filtering.'
)
temperature: float = dantic.Field(1.0, ge=0.0, le=1.0, description='The value used to modulate the next token probabilities.')
top_k: int = dantic.Field(50, description='The number of highest probability vocabulary tokens to keep for top-k-filtering.')
top_p: float = dantic.Field(
1.0,
description='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.',
@@ -196,44 +176,29 @@ class GenerationConfig(ReprMixin):
)
pad_token_id: int = dantic.Field(description='The id of the *padding* token.')
bos_token_id: int = dantic.Field(description='The id of the *beginning-of-sequence* token.')
eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
description='The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.'
)
eos_token_id: t.Union[int, t.List[int]] = dantic.Field(description='The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.')
encoder_no_repeat_ngram_size: int = dantic.Field(
0,
description='If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.',
)
decoder_start_token_id: int = dantic.Field(
description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.'
)
decoder_start_token_id: int = dantic.Field(description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.')
# NOTE: This is now implemented and supported for both PyTorch and vLLM
logprobs: int = dantic.Field(0, description='Number of log probabilities to return per output token.')
prompt_logprobs: int = dantic.Field(0, description='Number of log probabilities to return per input token.')
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
if not _internal:
raise RuntimeError(
'GenerationConfig is not meant to be used directly, but you can access this via LLMConfig.generation_config'
)
if not _internal: raise RuntimeError('GenerationConfig is not meant to be used directly, but you can access this via LLMConfig.generation_config')
self.__attrs_init__(**attrs)
def __getitem__(self, item: str) -> t.Any:
if hasattr(self, item):
return getattr(self, item)
if hasattr(self, item): return getattr(self, item)
raise KeyError(f"'{self.__class__.__name__}' has no attribute {item}.")
@property
def __repr_keys__(self) -> set[str]:
return {i.name for i in attr.fields(self.__class__)}
def __repr_keys__(self) -> set[str]: return {i.name for i in attr.fields(self.__class__)}
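NOTE (illustration only, not part of this commit): a minimal sketch of how the guarded constructor above is meant to be used. The FlanT5Config subclass and its __config__ values are hypothetical; default_id, model_ids and architecture are the Required ModelSettings fields shown further down, and whether an LLMConfig subclass is constructible with no arguments is assumed here.

import openllm_core

class FlanT5Config(openllm_core.LLMConfig):  # hypothetical subclass for illustration
    __config__ = {
        'default_id': 'google/flan-t5-large',
        'model_ids': ['google/flan-t5-large'],
        'architecture': 'T5ForConditionalGeneration',
    }

config = FlanT5Config()
assert config.generation_config['max_new_tokens'] == 20  # dantic.Field default above, via __getitem__
# GenerationConfig() without _internal=True raises RuntimeError by design.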
converter.register_unstructure_hook_factory(
lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
lambda cls: make_dict_unstructure_fn(
cls,
converter,
_cattrs_omit_if_default=False,
_cattrs_use_linecache=True,
cls, converter, #
**{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)},
),
)
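NOTE: as a sanity check for the omit logic above, here is a standalone sketch of the same pattern with plain attrs/cattrs (the Example class is hypothetical): fields whose default is None or attr.NOTHING are dropped from the unstructured dict, concrete defaults are kept.

import attr
from cattrs import Converter
from cattrs.gen import make_dict_unstructure_fn, override

c = Converter()

@attr.define
class Example:
    kept: int = 20          # concrete default -> serialised
    omitted: object = None  # None default -> omitted, mirroring the hook above

c.register_unstructure_hook(
    Example,
    make_dict_unstructure_fn(
        Example, c,
        **{k: override(omit=True) for k, v in attr.fields_dict(Example).items() if v.default in (None, attr.NOTHING)},
    ),
)
assert c.unstructure(Example()) == {'kept': 20}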
@@ -394,15 +359,6 @@ _object_getattribute = object.__getattribute__
class ModelSettings(t.TypedDict, total=False):
'''ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__.
Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig.
If the fields below change, make sure to run ./tools/update-config-stubs.py to regenerate the correct __getitem__
stubs for type-checking purposes.
'''
# NOTE: These required fields should be at the top, as they will be kw_only
default_id: Required[str]
model_ids: Required[ListStr]
architecture: Required[str]
@@ -435,15 +391,13 @@ _transformed_type: DictStrAny = {'fine_tune_strategies': t.Dict[AdapterType, Fin
@attr.define(
frozen=False,
slots=True,
frozen=False, slots=True, #
field_transformer=lambda _, __: [
attr.Attribute.from_counting_attr(
k,
dantic.Field(
kw_only=False if t.get_origin(ann) is not Required else True,
auto_default=True,
use_default_converter=False,
auto_default=True, use_default_converter=False, #
type=_transformed_type.get(k, ann),
metadata={'target': f'__openllm_{k}__'},
description=f'ModelSettings field for {k}.',
@@ -454,11 +408,10 @@ _transformed_type: DictStrAny = {'fine_tune_strategies': t.Dict[AdapterType, Fin
)
class _ModelSettingsAttr:
def __getitem__(self, key: str) -> t.Any:
if key in codegen.get_annotations(ModelSettings):
return _object_getattribute(self, key)
if key in codegen.get_annotations(ModelSettings): return _object_getattribute(self, key)
raise KeyError(key)
# NOTE: The below are dynamically generated by the field_transformer
@classmethod
def from_settings(cls, settings: ModelSettings) -> _ModelSettingsAttr: return cls(**settings)
if t.TYPE_CHECKING:
# update-config-stubs.py: attrs start
default_id: str
@@ -479,22 +432,15 @@ class _ModelSettingsAttr:
fine_tune_strategies: t.Dict[AdapterType, FineTuneConfig]
# update-config-stubs.py: attrs stop
_DEFAULT = _ModelSettingsAttr(
**ModelSettings(
default_id='__default__',
model_ids=['__default__'],
architecture='PreTrainedModel',
serialisation='legacy',
_DEFAULT = _ModelSettingsAttr.from_settings(
ModelSettings(
name_type='dasherize', url='', #
backend=('pt', 'vllm', 'ctranslate'),
name_type='dasherize',
url='',
model_type='causal_lm',
trust_remote_code=False,
requirements=None,
timeout=int(36e6),
service_name='',
workers_per_resource=1.0,
timeout=int(36e6), service_name='', #
model_type='causal_lm', requirements=None, #
trust_remote_code=False, workers_per_resource=1.0, #
default_id='__default__', model_ids=['__default__'], #
architecture='PreTrainedModel', serialisation='legacy', #
)
)
@@ -504,62 +450,39 @@ def structure_settings(cls: type[LLMConfig], _: type[_ModelSettingsAttr]) -> _Mo
has_custom_name = all(i in cls.__config__ for i in {'model_name', 'start_name'})
_config = attr.evolve(_DEFAULT, **cls.__config__)
_attr = {}
if not has_custom_name:
_attr['model_name'] = inflection.underscore(_cl_name) if _config['name_type'] == 'dasherize' else _cl_name.lower()
_attr['start_name'] = (
inflection.dasherize(_attr['model_name']) if _config['name_type'] == 'dasherize' else _attr['model_name']
)
model_name = _attr['model_name'] if 'model_name' in _attr else _config.model_name
if _config['name_type'] == 'dasherize':
_attr['model_name'] = inflection.underscore(_cl_name)
_attr['start_name'] = inflection.dasherize(_attr['model_name'])
else:
_attr['model_name'] = _cl_name.lower()
_attr['start_name'] = _attr['model_name']
_attr.update(
{
'service_name': f'generated_{model_name}_service.py',
'service_name': f'generated_{_attr["model_name"] if "model_name" in _attr else _config.model_name}_service.py',
'fine_tune_strategies': {
ft_config.get('adapter_type', 'lora'): FineTuneConfig.from_config(ft_config, cls)
for ft_config in _config.fine_tune_strategies # ft_config is a dict here before transformer
}
if _config.fine_tune_strategies
else {},
ft_config.get('adapter_type', 'lora'): FineTuneConfig.from_config(ft_config, cls) for ft_config in _config.fine_tune_strategies
} if _config.fine_tune_strategies else {},
}
)
return attr.evolve(_config, **_attr)
converter.register_structure_hook(_ModelSettingsAttr, structure_settings)
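NOTE: a quick illustration of the 'dasherize' branch above. The class name is hypothetical; _cl_name is the LLMConfig subclass name with the 'Config' suffix stripped.

import inflection

_cl_name = 'FlanT5'                               # from a hypothetical FlanT5Config
model_name = inflection.underscore(_cl_name)      # 'flan_t5'
start_name = inflection.dasherize(model_name)     # 'flan-t5'
service_name = f'generated_{model_name}_service.py'  # 'generated_flan_t5_service.py'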
def _setattr_class(attr_name: str, value_var: t.Any) -> str:
return f"setattr(cls, '{attr_name}', {value_var})"
def _make_assignment_script(
cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: LiteralString = 'openllm'
) -> t.Callable[..., None]:
_reserved_namespace = {'__config__', 'GenerationConfig', 'SamplingParams'}
def _setattr_class(attr_name: str, value_var: t.Any) -> str: return f"setattr(cls, '{attr_name}', {value_var})"
def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance) -> t.Callable[[type[LLMConfig]], None]:
'''Generate the assignment script with prefix attributes __openllm_<value>__.'''
args: ListStr = []
globs: DictStrAny = {
'cls': cls,
'_cached_attribute': attributes,
'_cached_getattribute_get': _object_getattribute.__get__,
}
annotations: DictStrAny = {'return': None}
lines: ListStr = []
args, lines, annotations = [], [], {'return': None}
globs = {'cls': cls, '_cached_attribute': attributes}
for attr_name, field in attr.fields_dict(attributes.__class__).items():
arg_name = field.metadata.get('target', f'__{_prefix}_{inflection.underscore(attr_name)}__')
arg_name = field.metadata.get('target', f'__openllm_{inflection.underscore(attr_name)}__')
args.append(f"{attr_name}=getattr(_cached_attribute, '{attr_name}')")
lines.append(_setattr_class(arg_name, attr_name))
annotations[attr_name] = field.type
return codegen.generate_function(
cls, '__assign_attr', lines, args=('cls', *args), globs=globs, annotations=annotations
)
_reserved_namespace = {'__config__', 'GenerationConfig', 'SamplingParams'}
return codegen.generate_function(cls, '__assign_attr', lines, ('cls', *args), globs, annotations)
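NOTE: to make the control flow concrete, a self-contained sketch of the compile-and-setattr pattern that _make_assignment_script relies on (the field name is illustrative only):

def _setattr_line(attr_name: str, value_var: str) -> str:
    # same shape as _setattr_class above
    return f"setattr(cls, '{attr_name}', {value_var})"

script = 'def __assign_attr(cls, default_id):\n ' + _setattr_line('__openllm_default_id__', 'default_id')
globs: dict = {}
eval(compile(script, '<generated __assign_attr>', 'exec'), globs)

class Dummy: ...
globs['__assign_attr'](Dummy, '__default__')
assert Dummy.__openllm_default_id__ == '__default__'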
@attr.define(slots=True)
class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]):
@@ -958,8 +881,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__)
cls.__name__ = f'{cls.__name__}Config'
if not hasattr(cls, '__config__'):
raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.")
if not hasattr(cls, '__config__'): raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.")
# auto assignment attributes generated from __config__ after create the new slot class.
_make_assignment_script(cls, converter.structure(cls, _ModelSettingsAttr))(cls)

View File

@@ -1,46 +1,27 @@
from __future__ import annotations
from typing import Callable, Dict, Tuple, List, Literal, Any, TypeVar
import sys
import typing as t
import attr
if t.TYPE_CHECKING:
from ctranslate2 import Generator, Translator
from peft.peft_model import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
M = TypeVar('M')
T = TypeVar('T')
from .utils.lazy import VersionInfo
else:
# NOTE: t.Any is also a type
PeftModel = (
PreTrainedModel
) = PreTrainedTokenizer = PreTrainedTokenizerBase = PreTrainedTokenizerFast = Generator = Translator = t.Any
# NOTE: VersionInfo is from openllm.utils.lazy.VersionInfo
VersionInfo = t.Any
def get_literal_args(typ: Any) -> Tuple[str, ...]: return getattr(typ, '__args__', tuple())
AnyCallable = Callable[..., Any]
DictStrAny = Dict[str, Any]
ListStr = List[str]
At = TypeVar('At', bound=attr.AttrsInstance)
LiteralDtype = Literal['float16', 'float32', 'bfloat16', 'int8', 'int16']
LiteralSerialisation = Literal['safetensors', 'legacy']
LiteralQuantise = Literal['int8', 'int4', 'gptq', 'awq', 'squeezellm']
LiteralBackend = Literal['pt', 'vllm', 'ctranslate', 'triton'] # TODO: ggml
AdapterType = Literal['lora', 'adalora', 'adaption_prompt', 'prefix_tuning', 'p_tuning', 'prompt_tuning', 'ia3', 'loha', 'lokr']
LiteralVersionStrategy = Literal['release', 'nightly', 'latest', 'custom']
M = t.TypeVar('M', bound=t.Union[PreTrainedModel, PeftModel, Generator, Translator])
T = t.TypeVar('T', bound=t.Union[PreTrainedTokenizerFast, PreTrainedTokenizer, PreTrainedTokenizerBase])
class AdapterTuple(Tuple[Any, ...]): adapter_id: str; name: str; config: DictStrAny
def get_literal_args(typ: t.Any) -> tuple[str, ...]:
return getattr(typ, '__args__', tuple())
AnyCallable = t.Callable[..., t.Any]
DictStrAny = t.Dict[str, t.Any]
ListAny = t.List[t.Any]
ListStr = t.List[str]
TupleAny = t.Tuple[t.Any, ...]
At = t.TypeVar('At', bound=attr.AttrsInstance)
LiteralDtype = t.Literal['float16', 'float32', 'bfloat16', 'int8', 'int16']
LiteralSerialisation = t.Literal['safetensors', 'legacy']
LiteralQuantise = t.Literal['int8', 'int4', 'gptq', 'awq', 'squeezellm']
LiteralBackend = t.Literal['pt', 'vllm', 'ctranslate', 'triton'] # TODO: ggml
AdapterType = t.Literal[
'lora', 'adalora', 'adaption_prompt', 'prefix_tuning', 'p_tuning', 'prompt_tuning', 'ia3', 'loha', 'lokr'
]
LiteralVersionStrategy = t.Literal['release', 'nightly', 'latest', 'custom']
AdapterMap = Dict[AdapterType, Tuple[AdapterTuple, ...]]
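NOTE: for clarity, what get_literal_args yields for the aliases above, assuming typing's standard __args__ behaviour:

assert get_literal_args(LiteralBackend) == ('pt', 'vllm', 'ctranslate', 'triton')
assert get_literal_args(LiteralSerialisation) == ('safetensors', 'legacy')
assert get_literal_args(int) == ()  # no __args__ -> empty tuple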
if sys.version_info[:2] >= (3, 11):
from typing import (
@@ -75,12 +56,3 @@ if sys.version_info[:2] >= (3, 9):
from typing import Annotated as Annotated
else:
from typing_extensions import Annotated as Annotated
class AdapterTuple(TupleAny):
adapter_id: str
name: str
config: DictStrAny
AdapterMap = t.Dict[AdapterType, t.Tuple[AdapterTuple, ...]]

View File

@@ -11,9 +11,7 @@ import orjson
if t.TYPE_CHECKING:
import openllm_core
from openllm_core._typing_compat import AnyCallable, DictStrAny, ListStr, LiteralString
PartialAny = functools.partial[t.Any]
from openllm_core._typing_compat import AnyCallable, DictStrAny, LiteralString
_T = t.TypeVar('_T', bound=t.Callable[..., t.Any])
logger = logging.getLogger(__name__)
@@ -21,24 +19,17 @@ logger = logging.getLogger(__name__)
# sentinel object for unequivocal object() getattr
_sentinel = object()
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
attr = getattr(cls, attrib_name, _sentinel)
if attr is _sentinel:
return False
if attr is _sentinel: return False
for base_cls in cls.__mro__[1:]:
a = getattr(base_cls, attrib_name, None)
if attr is a:
return False
if attr is a: return False
return True
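NOTE: a small illustration of the identity check above — the attribute must be (re)bound on the class itself, not merely inherited:

class Base:
    x = 1

class Child(Base):
    pass

class Override(Base):
    x = 2

assert has_own_attribute(Base, 'x')
assert not has_own_attribute(Child, 'x')  # same object as Base.x -> inherited
assert has_own_attribute(Override, 'x')   # rebound on the subclass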
def get_annotations(cls: type[t.Any]) -> DictStrAny:
if has_own_attribute(cls, '__annotations__'):
return cls.__annotations__
return t.cast('DictStrAny', {})
if has_own_attribute(cls, '__annotations__'): return cls.__annotations__
return {}
def is_class_var(annot: str | t.Any) -> bool:
annot = str(annot)
@@ -47,7 +38,6 @@ def is_class_var(annot: str | t.Any) -> bool:
annot = annot[1:-1]
return annot.startswith(('typing.ClassVar', 't.ClassVar', 'ClassVar', 'typing_extensions.ClassVar'))
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
try:
method_or_cls.__module__ = cls.__module__
@@ -63,13 +53,9 @@ def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str
pass
return method_or_cls
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = '') -> None:
eval(compile(script, filename, 'exec'), globs, locs)
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = '') -> None: eval(compile(script, filename, 'exec'), globs, locs)
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
locs: DictStrAny = {}
locs: dict[str, t.Any | AnyCallable] = {}
# In order for debuggers like PDB to be able to step through the code, we add a fake linecache entry.
count = 1
base_filename = filename
@@ -84,18 +70,16 @@ def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> An
_compile_and_eval(script, globs, locs, filename)
return locs[name]
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
'''Create a tuple subclass to hold class attributes.
The subclass is a bare tuple with properties for names.
class MyClassAttributes(tuple):
__slots__ = ()
x = property(itemgetter(0))
__slots__ = ()
x = property(itemgetter(0))
'''
from . import SHOW_CODEGEN
attr_class_name = f'{cls_name}Attributes'
attr_class_template = [f'class {attr_class_name}(tuple):', ' __slots__ = ()']
if attr_names:
@@ -103,113 +87,73 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.
attr_class_template.append(f' {attr_name} = _attrs_property(_attrs_itemgetter({i}))')
else:
attr_class_template.append(' pass')
globs: DictStrAny = {'_attrs_itemgetter': itemgetter, '_attrs_property': property}
if SHOW_CODEGEN:
print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template))
globs = {'_attrs_itemgetter': itemgetter, '_attrs_property': property}
if SHOW_CODEGEN: print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template))
_compile_and_eval('\n'.join(attr_class_template), globs)
return globs[attr_class_name]
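NOTE: usage sketch (hypothetical names) — the generated class is a plain tuple whose positions are exposed as read-only named properties:

Coordinates = make_attr_tuple_class('Coordinates', ['x', 'y'])  # class is named CoordinatesAttributes
point = Coordinates((1, 2))
assert point.x == 1 and point.y == 2 and isinstance(point, tuple)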
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
def generate_function(
typ: type[t.Any],
func_name: str,
typ: type[t.Any], func_name: str, #
lines: list[str] | None,
args: tuple[str, ...] | None,
globs: dict[str, t.Any],
annotations: dict[str, t.Any] | None = None,
) -> AnyCallable:
from openllm_core.utils import SHOW_CODEGEN
script = 'def %s(%s):\n %s\n' % (
func_name,
', '.join(args) if args is not None else '',
'\n '.join(lines) if lines else 'pass',
)
script = 'def %s(%s):\n %s\n' % (func_name, ', '.join(args) if args is not None else '', '\n '.join(lines) if lines else 'pass')
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
if annotations:
meth.__annotations__ = annotations
if SHOW_CODEGEN:
print(f'Generated script for {typ}:\n\n', script)
if annotations: meth.__annotations__ = annotations
if SHOW_CODEGEN: print(f'Generated script for {typ}:\n\n', script)
return meth
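NOTE: a hedged usage sketch of generate_function itself; globs must supply every name the generated body references (the anchor class is hypothetical, used only for the generated filename):

class _Anchor: ...

double = generate_function(
    _Anchor, 'double', ['return factor * value'], ('value',), {'factor': 2}, {'value': int, 'return': int}
)
assert double(21) == 42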
def make_env_transformer(
cls: type[openllm_core.LLMConfig],
model_name: str,
cls: type[openllm_core.LLMConfig], model_name: str, #
suffix: LiteralString | None = None,
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
globs: DictStrAny | None = None,
) -> AnyCallable:
from openllm_core.utils import dantic, field_env_key
def identity(_: str, x_value: t.Any) -> t.Any:
return x_value
def identity(_: str, x_value: t.Any) -> t.Any: return x_value
default_callback = identity if default_callback is None else default_callback
globs = {} if globs is None else globs
globs.update(
{
'__populate_env': dantic.env_converter,
'__default_callback': default_callback,
'__field_env': field_env_key,
'__suffix': suffix or '',
'__model_name': model_name,
'__populate_env': dantic.env_converter, '__field_env': field_env_key, #
'__suffix': suffix or '', '__model_name': model_name, #
'__default_callback': identity if default_callback is None else default_callback,
}
)
lines: ListStr = [
'__env=lambda field_name:__field_env(field_name,__suffix)',
"return [f.evolve(default=__populate_env(__default_callback(f.name,f.default),__env(f.name)),metadata={'env':f.metadata.get('env',__env(f.name)),'description':f.metadata.get('description', '(not provided)')}) for f in fields]",
]
fields_ann = 'list[attr.Attribute[t.Any]]'
return generate_function(
cls,
'__auto_env',
lines,
args=('_', 'fields'),
globs=globs,
annotations={'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann},
cls, '__auto_env', #
['__env=lambda field_name:__field_env(field_name,__suffix)', "return [f.evolve(default=__populate_env(__default_callback(f.name,f.default),__env(f.name)),metadata={'env':f.metadata.get('env',__env(f.name)),'description':f.metadata.get('description', '(not provided)')}) for f in fields]"],
('_', 'fields'), globs, {'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann}, #
)
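NOTE: the generated transformer is easier to read as a standalone sketch. The DEMO_* env key below is an assumption for illustration, not the library's field_env_key scheme; the point is an attrs field_transformer that swaps a field's default for an environment value when one is set.

import os
import attr

os.environ['DEMO_TEMPERATURE'] = '0.7'

def env_transformer(cls, fields):
    # evolve each field so its default comes from the environment when present
    return [f.evolve(default=os.environ.get(f'DEMO_{f.name.upper()}', f.default)) for f in fields]

@attr.define(field_transformer=env_transformer)
class Settings:
    temperature: str = '1.0'

assert Settings().temperature == '0.7'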
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
from .representation import ReprMixin
if name is None:
name = func.__name__.strip('_')
if name is None: name = func.__name__.strip('_')
_signatures = inspect.signature(func).parameters
def _repr(self: ReprMixin) -> str:
return f'<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>'
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
if func.__doc__ is None:
doc = f'Generated SDK for {func.__name__}'
else:
doc = func.__doc__
return t.cast(
_T,
functools.update_wrapper(
types.new_class(
name,
(t.cast('PartialAny', functools.partial), ReprMixin),
exec_body=lambda ns: ns.update(
{
'__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]),
'__repr_args__': _repr_args,
'__repr__': _repr,
'__doc__': inspect.cleandoc(doc),
'__module__': 'openllm',
}
),
)(func, **attrs),
func,
),
def _repr(self: ReprMixin) -> str: return f'<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>'
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
return functools.update_wrapper(
types.new_class(
name,
(functools.partial, ReprMixin),
exec_body=lambda ns: ns.update(
{
'__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]),
'__repr_args__': _repr_args, '__repr__': _repr, #
'__doc__': inspect.cleandoc(f'Generated SDK for {func.__name__}' if func.__doc__ is None else func.__doc__),
'__module__': 'openllm',
}
),
)(func, **attrs),
func,
)
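NOTE: usage sketch for gen_sdk. The _chat function is hypothetical, and string annotations are used so the generated repr stays JSON-serialisable:

def _chat(prompt: 'str', temperature: 'float' = 1.0):
    '''Send a single prompt.'''
    return prompt

chat = gen_sdk(_chat)            # name defaults to func.__name__.strip('_') -> 'chat'
assert chat('hello') == 'hello'  # the object is a functools.partial subclass wrapping _chat
print(chat)  # <generated function chat { "prompt": "str", "temperature": "float" }>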

View File

@@ -1,35 +1,17 @@
from __future__ import annotations
import typing as t
from abc import abstractmethod
import attr
import orjson
import attr, orjson
from openllm_core import utils
if t.TYPE_CHECKING:
from openllm_core._typing_compat import TypeAlias
if t.TYPE_CHECKING: from openllm_core._typing_compat import TypeAlias
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
class ReprMixin:
@property
@abstractmethod
def __repr_keys__(self) -> set[str]:
raise NotImplementedError
def __repr__(self):
return f'{self.__class__.__name__} {orjson.dumps({k: utils.converter.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
def __str__(self):
return self.__repr_str__(' ')
def __repr_name__(self) -> str:
return self.__class__.__name__
def __repr_str__(self, join_str: str):
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __repr_args__(self) -> ReprArgs:
return ((k, getattr(self, k)) for k in self.__repr_keys__)
def __repr_keys__(self) -> set[str]: raise NotImplementedError
def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: utils.converter.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
def __str__(self) -> str: return self.__repr_str__(' ')
def __repr_name__(self) -> str: return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __repr_args__(self) -> ReprArgs: return ((k, getattr(self, k)) for k in self.__repr_keys__)
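NOTE: a minimal consumer of ReprMixin for reference (hypothetical class; repr=False so attrs does not shadow the mixin's __repr__):

import attr

@attr.define(repr=False)
class Endpoint(ReprMixin):
    host: str
    port: int

    @property
    def __repr_keys__(self) -> set:
        return {'host', 'port'}

ep = Endpoint('localhost', 3000)
print(str(ep))   # host='localhost' port=3000 (order follows the key set)
print(repr(ep))  # Endpoint { "host": "localhost", "port": 3000 } as indented JSON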

View File

@@ -6,19 +6,16 @@ from openllm_core._typing_compat import (
AdapterMap,
AdapterTuple,
AdapterType,
DictStrAny,
LiteralBackend,
LiteralDtype,
LiteralQuantise,
LiteralSerialisation,
M,
T,
TupleAny,
)
from openllm_core.exceptions import MissingDependencyError
from openllm_core.utils import (
DEBUG,
ReprMixin,
apply,
check_bool_env,
codegen,
@@ -49,7 +46,7 @@ ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]
@attr.define(slots=True, repr=False, init=False)
class LLM(t.Generic[M, T], ReprMixin):
class LLM(t.Generic[M, T]):
async def generate(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
) -> GenerationOutput:
@@ -90,13 +87,12 @@ class LLM(t.Generic[M, T], ReprMixin):
self.runner.init_local(quiet=True)
config = self.config.model_construct_env(**attrs)
if stop_token_ids is None:
stop_token_ids = []
if stop_token_ids is None: stop_token_ids = []
eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
if eos_token_id is not None:
if not isinstance(eos_token_id, list):
eos_token_id = [eos_token_id]
if not isinstance(eos_token_id, list): eos_token_id = [eos_token_id]
stop_token_ids.extend(eos_token_id)
if config['eos_token_id'] and config['eos_token_id'] not in stop_token_ids: stop_token_ids.append(config['eos_token_id'])
if self.tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(self.tokenizer.eos_token_id)
if stop is None:
@@ -142,9 +138,9 @@ class LLM(t.Generic[M, T], ReprMixin):
_revision: t.Optional[str]
_quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
_quantise: t.Optional[LiteralQuantise]
_model_decls: TupleAny
__model_attrs: DictStrAny
__tokenizer_attrs: DictStrAny
_model_decls: t.Tuple[t.Any, ...]
__model_attrs: t.Dict[str, t.Any]
__tokenizer_attrs: t.Dict[str, t.Any]
_tag: bentoml.Tag
_adapter_map: t.Optional[AdapterMap]
_serialisation: LiteralSerialisation
@@ -308,9 +304,8 @@ class LLM(t.Generic[M, T], ReprMixin):
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
except AttributeError:
pass
@property
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type))
def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
@property
def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
@property

View File

@@ -1,4 +1,4 @@
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import attr
import torch
@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
M,
T,
)
from openllm_core.utils.representation import ReprArgs
from ._quantisation import QuantizationConfig
from ._runners import Runner
@@ -59,13 +58,7 @@ class LLM(Generic[M, T]):
__llm_adapter_map__: Optional[ResolvedAdapterMap] = ...
__llm_trust_remote_code__: bool = ...
@property
def __repr_keys__(self) -> Set[str]: ...
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def __repr_name__(self) -> str: ...
def __repr_str__(self, join_str: str) -> str: ...
def __repr_args__(self) -> ReprArgs: ...
def __init__(
self,
model_id: str,