mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-05 14:10:32 -05:00
fix(style): reduce boilerplate and format to custom logics
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -6,19 +6,16 @@ from openllm_core._typing_compat import (
|
||||
AdapterMap,
|
||||
AdapterTuple,
|
||||
AdapterType,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralDtype,
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
M,
|
||||
T,
|
||||
TupleAny,
|
||||
)
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
ReprMixin,
|
||||
apply,
|
||||
check_bool_env,
|
||||
codegen,
|
||||
@@ -49,7 +46,7 @@ ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]
|
||||
|
||||
|
||||
@attr.define(slots=True, repr=False, init=False)
|
||||
class LLM(t.Generic[M, T], ReprMixin):
|
||||
class LLM(t.Generic[M, T]):
|
||||
async def generate(
|
||||
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
|
||||
) -> GenerationOutput:
|
||||
@@ -90,13 +87,12 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
self.runner.init_local(quiet=True)
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
|
||||
if eos_token_id is not None:
|
||||
if not isinstance(eos_token_id, list):
|
||||
eos_token_id = [eos_token_id]
|
||||
if not isinstance(eos_token_id, list): eos_token_id = [eos_token_id]
|
||||
stop_token_ids.extend(eos_token_id)
|
||||
if config['eos_token_id'] and config['eos_token_id'] not in stop_token_ids: stop_token_ids.append(config['eos_token_id'])
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if stop is None:
|
||||
@@ -142,9 +138,9 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
_revision: t.Optional[str]
|
||||
_quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
|
||||
_quantise: t.Optional[LiteralQuantise]
|
||||
_model_decls: TupleAny
|
||||
__model_attrs: DictStrAny
|
||||
__tokenizer_attrs: DictStrAny
|
||||
_model_decls: t.Tuple[t.Any, ...]
|
||||
__model_attrs: t.Dict[str, t.Any]
|
||||
__tokenizer_attrs: t.Dict[str, t.Any]
|
||||
_tag: bentoml.Tag
|
||||
_adapter_map: t.Optional[AdapterMap]
|
||||
_serialisation: LiteralSerialisation
|
||||
@@ -308,9 +304,8 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
|
||||
except AttributeError:
|
||||
pass
|
||||
@property
|
||||
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
|
||||
def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type))
|
||||
def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
|
||||
@property
|
||||
def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
|
||||
@property
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import attr
|
||||
import torch
|
||||
@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
|
||||
M,
|
||||
T,
|
||||
)
|
||||
from openllm_core.utils.representation import ReprArgs
|
||||
|
||||
from ._quantisation import QuantizationConfig
|
||||
from ._runners import Runner
|
||||
@@ -59,13 +58,7 @@ class LLM(Generic[M, T]):
|
||||
__llm_adapter_map__: Optional[ResolvedAdapterMap] = ...
|
||||
__llm_trust_remote_code__: bool = ...
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> Set[str]: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr_name__(self) -> str: ...
|
||||
def __repr_str__(self, join_str: str) -> str: ...
|
||||
def __repr_args__(self) -> ReprArgs: ...
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
||||
Reference in New Issue
Block a user