fix(style): reduce boilerplate and format to custom logics

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-11-26 01:44:59 -05:00
parent b4ea4b3e99
commit 69aae34cf4
6 changed files with 118 additions and 310 deletions

View File

@@ -6,19 +6,16 @@ from openllm_core._typing_compat import (
AdapterMap,
AdapterTuple,
AdapterType,
DictStrAny,
LiteralBackend,
LiteralDtype,
LiteralQuantise,
LiteralSerialisation,
M,
T,
TupleAny,
)
from openllm_core.exceptions import MissingDependencyError
from openllm_core.utils import (
DEBUG,
ReprMixin,
apply,
check_bool_env,
codegen,
@@ -49,7 +46,7 @@ ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]
@attr.define(slots=True, repr=False, init=False)
class LLM(t.Generic[M, T], ReprMixin):
class LLM(t.Generic[M, T]):
async def generate(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
) -> GenerationOutput:
@@ -90,13 +87,12 @@ class LLM(t.Generic[M, T], ReprMixin):
self.runner.init_local(quiet=True)
config = self.config.model_construct_env(**attrs)
if stop_token_ids is None:
stop_token_ids = []
if stop_token_ids is None: stop_token_ids = []
eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
if eos_token_id is not None:
if not isinstance(eos_token_id, list):
eos_token_id = [eos_token_id]
if not isinstance(eos_token_id, list): eos_token_id = [eos_token_id]
stop_token_ids.extend(eos_token_id)
if config['eos_token_id'] and config['eos_token_id'] not in stop_token_ids: stop_token_ids.append(config['eos_token_id'])
if self.tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(self.tokenizer.eos_token_id)
if stop is None:
@@ -142,9 +138,9 @@ class LLM(t.Generic[M, T], ReprMixin):
_revision: t.Optional[str]
_quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
_quantise: t.Optional[LiteralQuantise]
_model_decls: TupleAny
__model_attrs: DictStrAny
__tokenizer_attrs: DictStrAny
_model_decls: t.Tuple[t.Any, ...]
__model_attrs: t.Dict[str, t.Any]
__tokenizer_attrs: t.Dict[str, t.Any]
_tag: bentoml.Tag
_adapter_map: t.Optional[AdapterMap]
_serialisation: LiteralSerialisation
@@ -308,9 +304,8 @@ class LLM(t.Generic[M, T], ReprMixin):
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
except AttributeError:
pass
@property
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type))
def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
@property
def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
@property