fix(style): reduce boilerplate and format to custom logics

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron
Date: 2023-11-26 01:44:59 -05:00
Parent: b4ea4b3e99
Commit: 69aae34cf4
6 changed files with 118 additions and 310 deletions


@@ -6,19 +6,16 @@ from openllm_core._typing_compat import (
   AdapterMap,
   AdapterTuple,
   AdapterType,
-  DictStrAny,
   LiteralBackend,
   LiteralDtype,
   LiteralQuantise,
   LiteralSerialisation,
   M,
   T,
-  TupleAny,
 )
 from openllm_core.exceptions import MissingDependencyError
 from openllm_core.utils import (
   DEBUG,
-  ReprMixin,
   apply,
   check_bool_env,
   codegen,
@@ -49,7 +46,7 @@ ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]
 @attr.define(slots=True, repr=False, init=False)
-class LLM(t.Generic[M, T], ReprMixin):
+class LLM(t.Generic[M, T]):
   async def generate(
     self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
   ) -> GenerationOutput:
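
Note: most of what this commit deletes traces back to ReprMixin. Its internals are not shown in this diff, but the stubs removed from the .pyi file further down (__repr_keys__, __repr_name__, __repr_str__, __repr_args__) suggest a pydantic-style Representation pattern. A minimal sketch of that pattern, with illustrative names only, not the actual openllm_core.utils.ReprMixin:

```python
# Hedged sketch of a ReprMixin-style helper; the real openllm_core.utils.ReprMixin
# is not shown in this diff, so details below are assumptions based on the removed stubs.
import typing as t

ReprArgs = t.Iterable[t.Tuple[str, t.Any]]

class ReprMixin:
  @property
  def __repr_keys__(self) -> t.Set[str]:
    raise NotImplementedError  # subclasses declare which attributes to display

  def __repr_args__(self) -> ReprArgs:
    # default behaviour: look each declared key up as an attribute
    return ((k, getattr(self, k)) for k in self.__repr_keys__)

  def __repr_name__(self) -> str:
    return self.__class__.__name__

  def __repr_str__(self, join_str: str) -> str:
    return join_str.join(f'{k}={v!r}' for k, v in self.__repr_args__())

  def __repr__(self) -> str:
    return f"{self.__repr_name__()}({self.__repr_str__(', ')})"
```

With the mixin gone, LLM has to carry its own repr logic, which is what the hunk at -308 below inlines.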
@@ -90,13 +87,12 @@ class LLM(t.Generic[M, T], ReprMixin):
       self.runner.init_local(quiet=True)
     config = self.config.model_construct_env(**attrs)
-    if stop_token_ids is None:
-      stop_token_ids = []
+    if stop_token_ids is None: stop_token_ids = []
     eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
     if eos_token_id is not None:
-      if not isinstance(eos_token_id, list):
-        eos_token_id = [eos_token_id]
+      if not isinstance(eos_token_id, list): eos_token_id = [eos_token_id]
       stop_token_ids.extend(eos_token_id)
+    if config['eos_token_id'] and config['eos_token_id'] not in stop_token_ids: stop_token_ids.append(config['eos_token_id'])
     if self.tokenizer.eos_token_id not in stop_token_ids:
       stop_token_ids.append(self.tokenizer.eos_token_id)
     if stop is None:
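
The collapsed conditionals above amount to a small normalisation routine: coerce stop_token_ids to a list, fold in any eos_token_id override, then make sure both the config's and the tokenizer's EOS ids are present. A standalone sketch of the same logic (the function and parameter names are hypothetical, not part of the diff):

```python
# Hypothetical standalone version of the stop-token normalisation above.
import typing as t

def normalise_stop_token_ids(
  stop_token_ids: t.Optional[t.List[int]],
  eos_token_id: t.Union[int, t.List[int], None],
  config_eos_id: t.Optional[int],
  tokenizer_eos_id: int,
) -> t.List[int]:
  if stop_token_ids is None: stop_token_ids = []
  if eos_token_id is not None:
    if not isinstance(eos_token_id, list): eos_token_id = [eos_token_id]
    stop_token_ids.extend(eos_token_id)
  if config_eos_id and config_eos_id not in stop_token_ids: stop_token_ids.append(config_eos_id)
  if tokenizer_eos_id not in stop_token_ids: stop_token_ids.append(tokenizer_eos_id)
  return stop_token_ids

assert normalise_stop_token_ids(None, 2, 2, 2) == [2]
assert normalise_stop_token_ids([0], None, 1, 2) == [0, 1, 2]
```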
@@ -142,9 +138,9 @@ class LLM(t.Generic[M, T], ReprMixin):
   _revision: t.Optional[str]
   _quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
   _quantise: t.Optional[LiteralQuantise]
-  _model_decls: TupleAny
-  __model_attrs: DictStrAny
-  __tokenizer_attrs: DictStrAny
+  _model_decls: t.Tuple[t.Any, ...]
+  __model_attrs: t.Dict[str, t.Any]
+  __tokenizer_attrs: t.Dict[str, t.Any]
   _tag: bentoml.Tag
   _adapter_map: t.Optional[AdapterMap]
   _serialisation: LiteralSerialisation
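
The annotation change here is a pure inlining: judging from the 1:1 replacements in this hunk, DictStrAny and TupleAny are simple aliases in openllm_core._typing_compat, presumably along these lines (their exact definitions are not shown in this diff):

```python
# Presumed definitions of the removed aliases.
import typing as t

DictStrAny = t.Dict[str, t.Any]
TupleAny = t.Tuple[t.Any, ...]
```

Spelling the types out drops two imports at the cost of slightly longer annotations; behaviour and type-checking are unchanged.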
@@ -308,9 +304,8 @@ class LLM(t.Generic[M, T], ReprMixin):
       del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
     except AttributeError:
       pass
+  @property
+  def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
+  def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type))
+  def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
   @property
   def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
   @property
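
The inlined __repr__ renders the pairs from __repr_args__ as pretty-printed JSON via orjson. A minimal runnable demo of the same pattern, with made-up attribute values standing in for a real LLM instance:

```python
# Demo of the inlined JSON repr pattern; values are illustrative only.
import orjson

class Demo:
  def __repr_args__(self):
    yield from (('model_id', 'facebook/opt-125m'), ('backend', 'pt'))

  def __repr__(self) -> str:
    return f'{self.__class__.__name__} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_INDENT_2).decode()}'

print(repr(Demo()))
# Demo {
#   "model_id": "facebook/opt-125m",
#   "backend": "pt"
# }
```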


@@ -1,4 +1,4 @@
-from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 import attr
 import torch
@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
   M,
   T,
 )
-from openllm_core.utils.representation import ReprArgs
 from ._quantisation import QuantizationConfig
 from ._runners import Runner
@@ -59,13 +58,7 @@ class LLM(Generic[M, T]):
   __llm_adapter_map__: Optional[ResolvedAdapterMap] = ...
   __llm_trust_remote_code__: bool = ...
-  @property
-  def __repr_keys__(self) -> Set[str]: ...
-  def __repr__(self) -> str: ...
-  def __str__(self) -> str: ...
-  def __repr_name__(self) -> str: ...
-  def __repr_str__(self, join_str: str) -> str: ...
-  def __repr_args__(self) -> ReprArgs: ...
   def __init__(
     self,
     model_id: str,