mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-18 14:10:52 -04:00
fix(base-image): update base image to include cuda for now (#720)
* fix(base-image): update base image to include cuda for now Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * fix: build core and client on release images Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: cleanup style changes Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -15,10 +15,6 @@ repos:
|
||||
verbose: true
|
||||
args: [--exit-non-zero-on-fix, --show-fixes]
|
||||
types_or: [python, pyi, jupyter]
|
||||
- id: ruff-format
|
||||
alias: rf
|
||||
verbose: true
|
||||
types_or: [python, pyi, jupyter]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy
|
||||
|
||||
@@ -88,8 +88,8 @@ max-line-length = 119
|
||||
[lint.flake8-quotes]
|
||||
avoid-escape = false
|
||||
inline-quotes = "single"
|
||||
multiline-quotes = "double"
|
||||
docstring-quotes = "double"
|
||||
multiline-quotes = "single"
|
||||
docstring-quotes = "single"
|
||||
|
||||
[lint.extend-per-file-ignores]
|
||||
"openllm-python/tests/**/*" = ["S101", "TID252", "PT011", "S307"]
|
||||
|
||||
@@ -22,9 +22,9 @@ __all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Help
|
||||
|
||||
@attr.define
|
||||
class Metadata(_SchemaMixin):
|
||||
"""NOTE: Metadata is a modified version of the original MetadataOutput from openllm-core.
|
||||
'''NOTE: Metadata is a modified version of the original MetadataOutput from openllm-core.
|
||||
|
||||
The configuration is now structured into a dictionary for easy of use."""
|
||||
The configuration is now structured into a dictionary for easy of use.'''
|
||||
|
||||
model_id: str
|
||||
timeout: int
|
||||
|
||||
@@ -67,12 +67,12 @@ _object_setattr = object.__setattr__
|
||||
|
||||
@attr.frozen(slots=True, repr=False, init=False)
|
||||
class GenerationConfig(ReprMixin):
|
||||
"""GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor.
|
||||
'''GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor.
|
||||
|
||||
Note that we always set `do_sample=True`. This class is not designed to be used directly, rather
|
||||
to be used conjunction with LLMConfig. The instance of the generation config can then be accessed
|
||||
via ``LLMConfig.generation_config``.
|
||||
"""
|
||||
'''
|
||||
|
||||
max_new_tokens: int = dantic.Field(
|
||||
20, ge=0, description='The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.'
|
||||
@@ -241,14 +241,14 @@ _GenerationConfigT = t.TypeVar('_GenerationConfigT', bound=GenerationConfig)
|
||||
|
||||
@attr.frozen(slots=True, repr=False, init=False)
|
||||
class SamplingParams(ReprMixin):
|
||||
"""SamplingParams is the attr-compatible version of ``vllm.SamplingParams``. It provides some utilities to also respect shared variables from ``openllm.LLMConfig``.
|
||||
'''SamplingParams is the attr-compatible version of ``vllm.SamplingParams``. It provides some utilities to also respect shared variables from ``openllm.LLMConfig``.
|
||||
|
||||
The following value will be parsed directly from ``openllm.LLMConfig``:
|
||||
- temperature
|
||||
- top_k
|
||||
- top_p
|
||||
- max_tokens -> max_new_tokens
|
||||
"""
|
||||
'''
|
||||
|
||||
n: int = dantic.Field(1, description='Number of output sequences to return for the given prompt.')
|
||||
best_of: int = dantic.Field(
|
||||
@@ -317,7 +317,7 @@ class SamplingParams(ReprMixin):
|
||||
|
||||
@classmethod
|
||||
def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> Self:
|
||||
"""The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``."""
|
||||
'''The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``.'''
|
||||
if 'max_tokens' in attrs and 'max_new_tokens' in attrs:
|
||||
raise ValueError("Both 'max_tokens' and 'max_new_tokens' are passed. Make sure to only use one of them.")
|
||||
temperature = first_not_none(attrs.pop('temperature', None), default=generation_config['temperature'])
|
||||
@@ -371,13 +371,13 @@ _object_getattribute = object.__getattribute__
|
||||
|
||||
|
||||
class ModelSettings(t.TypedDict, total=False):
|
||||
"""ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__.
|
||||
'''ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__.
|
||||
|
||||
Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig.
|
||||
|
||||
If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__
|
||||
stubs for type-checking purposes.
|
||||
"""
|
||||
'''
|
||||
|
||||
# NOTE: These required fields should be at the top, as it will be kw_only
|
||||
default_id: Required[str]
|
||||
@@ -440,27 +440,25 @@ class _ModelSettingsAttr:
|
||||
|
||||
# NOTE: The below are dynamically generated by the field_transformer
|
||||
if t.TYPE_CHECKING:
|
||||
# fmt: off
|
||||
# update-config-stubs.py: attrs start
|
||||
default_id:str
|
||||
model_ids:ListStr
|
||||
architecture:str
|
||||
url:str
|
||||
serialisation:LiteralSerialisation
|
||||
trust_remote_code:bool
|
||||
service_name:str
|
||||
requirements:t.Optional[ListStr]
|
||||
model_type:t.Literal['causal_lm', 'seq2seq_lm']
|
||||
name_type:t.Optional[t.Literal['dasherize', 'lowercase']]
|
||||
backend:t.Tuple[LiteralBackend, ...]
|
||||
model_name:str
|
||||
start_name:str
|
||||
timeout:int
|
||||
workers_per_resource:t.Union[int, float]
|
||||
fine_tune_strategies:t.Dict[AdapterType, FineTuneConfig]
|
||||
add_generation_prompt:bool
|
||||
default_id: str
|
||||
model_ids: ListStr
|
||||
architecture: str
|
||||
url: str
|
||||
serialisation: LiteralSerialisation
|
||||
trust_remote_code: bool
|
||||
service_name: str
|
||||
requirements: t.Optional[ListStr]
|
||||
model_type: t.Literal['causal_lm', 'seq2seq_lm']
|
||||
name_type: t.Optional[t.Literal['dasherize', 'lowercase']]
|
||||
backend: t.Tuple[LiteralBackend, ...]
|
||||
model_name: str
|
||||
start_name: str
|
||||
timeout: int
|
||||
workers_per_resource: t.Union[int, float]
|
||||
fine_tune_strategies: t.Dict[AdapterType, FineTuneConfig]
|
||||
add_generation_prompt: bool
|
||||
# update-config-stubs.py: attrs stop
|
||||
# fmt: on
|
||||
|
||||
|
||||
_DEFAULT = _ModelSettingsAttr(
|
||||
@@ -521,7 +519,7 @@ def _setattr_class(attr_name: str, value_var: t.Any) -> str:
|
||||
def _make_assignment_script(
|
||||
cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: LiteralString = 'openllm'
|
||||
) -> t.Callable[..., None]:
|
||||
"""Generate the assignment script with prefix attributes __openllm_<value>__."""
|
||||
'''Generate the assignment script with prefix attributes __openllm_<value>__.'''
|
||||
args: ListStr = []
|
||||
globs: DictStrAny = {
|
||||
'cls': cls,
|
||||
@@ -549,14 +547,14 @@ _reserved_namespace = {'__config__', 'GenerationConfig', 'SamplingParams'}
|
||||
class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]):
|
||||
@staticmethod
|
||||
def Field(default: t.Any = None, **attrs: t.Any) -> t.Any:
|
||||
"""Field is a alias to the internal dantic utilities to easily create
|
||||
'''Field is a alias to the internal dantic utilities to easily create
|
||||
attrs.fields with pydantic-compatible interface. For example:
|
||||
|
||||
```python
|
||||
class MyModelConfig(openllm.LLMConfig):
|
||||
field1 = openllm.LLMConfig.Field(...)
|
||||
```
|
||||
"""
|
||||
'''
|
||||
return dantic.Field(default, **attrs)
|
||||
|
||||
# NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
|
||||
@@ -616,21 +614,20 @@ class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]):
|
||||
This class will also be managed internally by OpenLLM."""
|
||||
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None:
|
||||
"""Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract."""
|
||||
'''Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.'''
|
||||
|
||||
# NOTE: The following will be populated from __config__ and also
|
||||
# considered to be public API. Users can also access these via self[key]
|
||||
# To update the docstring for these field, update it through tools/update-config-stubs.py
|
||||
|
||||
# fmt: off
|
||||
# update-config-stubs.py: special start
|
||||
__openllm_default_id__:str=Field(None)
|
||||
__openllm_default_id__: str = Field(None)
|
||||
'''Return the default model to use when using 'openllm start <model_id>'.
|
||||
This could be one of the keys in 'self.model_ids' or custom users model.
|
||||
|
||||
This field is required when defining under '__config__'.
|
||||
'''
|
||||
__openllm_model_ids__:ListStr=Field(None)
|
||||
__openllm_model_ids__: ListStr = Field(None)
|
||||
'''A list of supported pretrained models tag for this given runnable.
|
||||
|
||||
For example:
|
||||
@@ -639,7 +636,7 @@ class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]):
|
||||
|
||||
This field is required when defining under '__config__'.
|
||||
'''
|
||||
__openllm_architecture__:str=Field(None)
|
||||
__openllm_architecture__: str = Field(None)
|
||||
'''The model architecture that is supported by this LLM.
|
||||
|
||||
Note that any model weights within this architecture generation can always be run and supported by this LLM.
|
||||
@@ -650,31 +647,31 @@ class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]):
|
||||
```bash
|
||||
openllm start stabilityai/stablelm-tuned-alpha-3b
|
||||
```'''
|
||||
__openllm_url__:str=Field(None)
|
||||
__openllm_url__: str = Field(None)
|
||||
'''The resolved url for this LLMConfig.'''
|
||||
__openllm_serialisation__:LiteralSerialisation=Field(None)
|
||||
__openllm_serialisation__: LiteralSerialisation = Field(None)
|
||||
'''Default serialisation format for different models. Some will default to use the legacy 'bin'. '''
|
||||
__openllm_trust_remote_code__:bool=Field(None)
|
||||
__openllm_trust_remote_code__: bool = Field(None)
|
||||
'''Whether to always trust remote code'''
|
||||
__openllm_service_name__:str=Field(None)
|
||||
__openllm_service_name__: str = Field(None)
|
||||
'''Generated service name for this LLMConfig. By default, it is "generated_{model_name}_service.py"'''
|
||||
__openllm_requirements__:t.Optional[ListStr]=Field(None)
|
||||
__openllm_requirements__: t.Optional[ListStr] = Field(None)
|
||||
'''The default PyPI requirements needed to run this given LLM. By default, we will depend on bentoml, torch, transformers.'''
|
||||
__openllm_model_type__:t.Literal['causal_lm', 'seq2seq_lm']=Field(None)
|
||||
__openllm_model_type__: t.Literal['causal_lm', 'seq2seq_lm'] = Field(None)
|
||||
'''The model type for this given LLM. By default, it should be causal language modeling. Currently supported "causal_lm" or "seq2seq_lm"'''
|
||||
__openllm_name_type__:t.Optional[t.Literal['dasherize', 'lowercase']]=Field(None)
|
||||
__openllm_name_type__: t.Optional[t.Literal['dasherize', 'lowercase']] = Field(None)
|
||||
'''The default name typed for this model. "dasherize" will convert the name to lowercase and
|
||||
replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both
|
||||
`model_name` and `start_name` must be specified.'''
|
||||
__openllm_backend__:t.Tuple[LiteralBackend, ...]=Field(None)
|
||||
__openllm_backend__: t.Tuple[LiteralBackend, ...] = Field(None)
|
||||
'''List of supported backend for this given LLM class. Currently, we support "pt" and "vllm".'''
|
||||
__openllm_model_name__:str=Field(None)
|
||||
__openllm_model_name__: str = Field(None)
|
||||
'''The normalized version of __openllm_start_name__, determined by __openllm_name_type__'''
|
||||
__openllm_start_name__:str=Field(None)
|
||||
__openllm_start_name__: str = Field(None)
|
||||
'''Default name to be used with `openllm start`'''
|
||||
__openllm_timeout__:int=Field(None)
|
||||
__openllm_timeout__: int = Field(None)
|
||||
'''The default timeout to be set for this given LLM.'''
|
||||
__openllm_workers_per_resource__:t.Union[int, float]=Field(None)
|
||||
__openllm_workers_per_resource__: t.Union[int, float] = Field(None)
|
||||
'''The number of workers per resource. This is used to determine the number of workers to use for this model.
|
||||
For example, if this is set to 0.5, then OpenLLM will use 1 worker per 2 resources. If this is set to 1, then
|
||||
OpenLLM will use 1 worker per resource. If this is set to 2, then OpenLLM will use 2 workers per resource.
|
||||
@@ -684,12 +681,11 @@ class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]):
|
||||
|
||||
By default, it is set to 1.
|
||||
'''
|
||||
__openllm_fine_tune_strategies__:t.Dict[AdapterType, FineTuneConfig]=Field(None)
|
||||
__openllm_fine_tune_strategies__: t.Dict[AdapterType, FineTuneConfig] = Field(None)
|
||||
'''The fine-tune strategies for this given LLM.'''
|
||||
__openllm_add_generation_prompt__:bool=Field(None)
|
||||
__openllm_add_generation_prompt__: bool = Field(None)
|
||||
'''Whether to add generation prompt token for formatting chat templates. This arguments will be used for chat-based models.'''
|
||||
# update-config-stubs.py: special stop
|
||||
# fmt: on
|
||||
|
||||
|
||||
class _ConfigBuilder:
|
||||
@@ -1097,171 +1093,169 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
|
||||
# The rest of attrs should only be the attributes to be passed to __attrs_init__
|
||||
self.__attrs_init__(**attrs)
|
||||
|
||||
# fmt: off
|
||||
# update-config-stubs.py: start
|
||||
# NOTE: ModelSettings arguments
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['default_id'])->str:...
|
||||
def __getitem__(self, item: t.Literal['default_id']) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['model_ids'])->ListStr:...
|
||||
def __getitem__(self, item: t.Literal['model_ids']) -> ListStr: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['architecture'])->str:...
|
||||
def __getitem__(self, item: t.Literal['architecture']) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['url'])->str:...
|
||||
def __getitem__(self, item: t.Literal['url']) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['serialisation'])->LiteralSerialisation:...
|
||||
def __getitem__(self, item: t.Literal['serialisation']) -> LiteralSerialisation: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['trust_remote_code'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['trust_remote_code']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['service_name'])->str:...
|
||||
def __getitem__(self, item: t.Literal['service_name']) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['requirements'])->t.Optional[ListStr]:...
|
||||
def __getitem__(self, item: t.Literal['requirements']) -> t.Optional[ListStr]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['model_type'])->t.Literal['causal_lm', 'seq2seq_lm']:...
|
||||
def __getitem__(self, item: t.Literal['model_type']) -> t.Literal['causal_lm', 'seq2seq_lm']: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['name_type'])->t.Optional[t.Literal['dasherize', 'lowercase']]:...
|
||||
def __getitem__(self, item: t.Literal['name_type']) -> t.Optional[t.Literal['dasherize', 'lowercase']]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['backend'])->t.Tuple[LiteralBackend, ...]:...
|
||||
def __getitem__(self, item: t.Literal['backend']) -> t.Tuple[LiteralBackend, ...]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['model_name'])->str:...
|
||||
def __getitem__(self, item: t.Literal['model_name']) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['start_name'])->str:...
|
||||
def __getitem__(self, item: t.Literal['start_name']) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['timeout'])->int:...
|
||||
def __getitem__(self, item: t.Literal['timeout']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['workers_per_resource'])->t.Union[int, float]:...
|
||||
def __getitem__(self, item: t.Literal['workers_per_resource']) -> t.Union[int, float]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['fine_tune_strategies'])->t.Dict[AdapterType, FineTuneConfig]:...
|
||||
def __getitem__(self, item: t.Literal['fine_tune_strategies']) -> t.Dict[AdapterType, FineTuneConfig]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['add_generation_prompt'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['add_generation_prompt']) -> bool: ...
|
||||
# NOTE: generation_class, sampling_class and extras arguments
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['generation_class'])->t.Type[openllm_core.GenerationConfig]:...
|
||||
def __getitem__(self, item: t.Literal['generation_class']) -> t.Type[openllm_core.GenerationConfig]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['sampling_class'])->t.Type[openllm_core.SamplingParams]:...
|
||||
def __getitem__(self, item: t.Literal['sampling_class']) -> t.Type[openllm_core.SamplingParams]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['extras'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['extras']) -> t.Dict[str, t.Any]: ...
|
||||
# NOTE: GenerationConfig arguments
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['max_new_tokens'])->int:...
|
||||
def __getitem__(self, item: t.Literal['max_new_tokens']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['min_length'])->int:...
|
||||
def __getitem__(self, item: t.Literal['min_length']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['min_new_tokens'])->int:...
|
||||
def __getitem__(self, item: t.Literal['min_new_tokens']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['early_stopping'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['early_stopping']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['max_time'])->float:...
|
||||
def __getitem__(self, item: t.Literal['max_time']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['num_beams'])->int:...
|
||||
def __getitem__(self, item: t.Literal['num_beams']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['num_beam_groups'])->int:...
|
||||
def __getitem__(self, item: t.Literal['num_beam_groups']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['penalty_alpha'])->float:...
|
||||
def __getitem__(self, item: t.Literal['penalty_alpha']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['use_cache'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['use_cache']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['temperature'])->float:...
|
||||
def __getitem__(self, item: t.Literal['temperature']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['top_k'])->int:...
|
||||
def __getitem__(self, item: t.Literal['top_k']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['top_p'])->float:...
|
||||
def __getitem__(self, item: t.Literal['top_p']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['typical_p'])->float:...
|
||||
def __getitem__(self, item: t.Literal['typical_p']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['epsilon_cutoff'])->float:...
|
||||
def __getitem__(self, item: t.Literal['epsilon_cutoff']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['eta_cutoff'])->float:...
|
||||
def __getitem__(self, item: t.Literal['eta_cutoff']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['diversity_penalty'])->float:...
|
||||
def __getitem__(self, item: t.Literal['diversity_penalty']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['repetition_penalty'])->float:...
|
||||
def __getitem__(self, item: t.Literal['repetition_penalty']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['encoder_repetition_penalty'])->float:...
|
||||
def __getitem__(self, item: t.Literal['encoder_repetition_penalty']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['length_penalty'])->float:...
|
||||
def __getitem__(self, item: t.Literal['length_penalty']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['no_repeat_ngram_size'])->int:...
|
||||
def __getitem__(self, item: t.Literal['no_repeat_ngram_size']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['bad_words_ids'])->t.List[t.List[int]]:...
|
||||
def __getitem__(self, item: t.Literal['bad_words_ids']) -> t.List[t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['force_words_ids'])->t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]:...
|
||||
def __getitem__(self, item: t.Literal['force_words_ids']) -> t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['renormalize_logits'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['renormalize_logits']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['forced_bos_token_id'])->int:...
|
||||
def __getitem__(self, item: t.Literal['forced_bos_token_id']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['forced_eos_token_id'])->t.Union[int, t.List[int]]:...
|
||||
def __getitem__(self, item: t.Literal['forced_eos_token_id']) -> t.Union[int, t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['remove_invalid_values'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['remove_invalid_values']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['exponential_decay_length_penalty'])->t.Tuple[int, float]:...
|
||||
def __getitem__(self, item: t.Literal['exponential_decay_length_penalty']) -> t.Tuple[int, float]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['suppress_tokens'])->t.List[int]:...
|
||||
def __getitem__(self, item: t.Literal['suppress_tokens']) -> t.List[int]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['begin_suppress_tokens'])->t.List[int]:...
|
||||
def __getitem__(self, item: t.Literal['begin_suppress_tokens']) -> t.List[int]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['forced_decoder_ids'])->t.List[t.List[int]]:...
|
||||
def __getitem__(self, item: t.Literal['forced_decoder_ids']) -> t.List[t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['num_return_sequences'])->int:...
|
||||
def __getitem__(self, item: t.Literal['num_return_sequences']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['output_attentions'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['output_attentions']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['output_hidden_states'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['output_hidden_states']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['output_scores'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['output_scores']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['pad_token_id'])->int:...
|
||||
def __getitem__(self, item: t.Literal['pad_token_id']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['bos_token_id'])->int:...
|
||||
def __getitem__(self, item: t.Literal['bos_token_id']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['eos_token_id'])->t.Union[int, t.List[int]]:...
|
||||
def __getitem__(self, item: t.Literal['eos_token_id']) -> t.Union[int, t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['encoder_no_repeat_ngram_size'])->int:...
|
||||
def __getitem__(self, item: t.Literal['encoder_no_repeat_ngram_size']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['decoder_start_token_id'])->int:...
|
||||
def __getitem__(self, item: t.Literal['decoder_start_token_id']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['logprobs'])->int:...
|
||||
def __getitem__(self, item: t.Literal['logprobs']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['prompt_logprobs'])->int:...
|
||||
def __getitem__(self, item: t.Literal['prompt_logprobs']) -> int: ...
|
||||
# NOTE: SamplingParams arguments
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['n'])->int:...
|
||||
def __getitem__(self, item: t.Literal['n']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['best_of'])->int:...
|
||||
def __getitem__(self, item: t.Literal['best_of']) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['presence_penalty'])->float:...
|
||||
def __getitem__(self, item: t.Literal['presence_penalty']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['frequency_penalty'])->float:...
|
||||
def __getitem__(self, item: t.Literal['frequency_penalty']) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['use_beam_search'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['use_beam_search']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['ignore_eos'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['ignore_eos']) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['skip_special_tokens'])->bool:...
|
||||
def __getitem__(self, item: t.Literal['skip_special_tokens']) -> bool: ...
|
||||
# NOTE: PeftType arguments
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['prompt_tuning'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['prompt_tuning']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['multitask_prompt_tuning'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['multitask_prompt_tuning']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['p_tuning'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['p_tuning']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['prefix_tuning'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['prefix_tuning']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['lora'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['lora']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['adalora'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['adalora']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['adaption_prompt'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['adaption_prompt']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['ia3'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['ia3']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['loha'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['loha']) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self,item:t.Literal['lokr'])->t.Dict[str, t.Any]:...
|
||||
def __getitem__(self, item: t.Literal['lokr']) -> t.Dict[str, t.Any]: ...
|
||||
# update-config-stubs.py: stop
|
||||
# fmt: on
|
||||
|
||||
def __getitem__(self, item: LiteralString | t.Any) -> t.Any:
|
||||
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as.
|
||||
@@ -1400,7 +1394,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
|
||||
|
||||
@classmethod
|
||||
def model_construct_env(cls, **attrs: t.Any) -> Self:
|
||||
"""A helpers that respect configuration values environment variables."""
|
||||
'''A helpers that respect configuration values environment variables.'''
|
||||
attrs = {k: v for k, v in attrs.items() if v is not None}
|
||||
env_json_string = os.environ.get('OPENLLM_CONFIG', None)
|
||||
|
||||
@@ -1434,7 +1428,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
|
||||
return converter.structure(config_from_env, cls)
|
||||
|
||||
def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]:
|
||||
"""Parse given click attributes into a LLMConfig and return the remaining click attributes."""
|
||||
'''Parse given click attributes into a LLMConfig and return the remaining click attributes.'''
|
||||
llm_config_attrs: DictStrAny = {'generation_config': {}, 'sampling_config': {}}
|
||||
key_to_remove: ListStr = []
|
||||
for k, v in attrs.items():
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import os
|
||||
import typing as t
|
||||
import importlib, os, typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import inflection, orjson
|
||||
from openllm_core.exceptions import MissingDependencyError, OpenLLMException
|
||||
from openllm_core.utils import ReprMixin, is_bentoml_available
|
||||
from openllm_core.utils.import_utils import is_transformers_available
|
||||
@@ -15,8 +10,7 @@ if t.TYPE_CHECKING:
|
||||
import types
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
|
||||
import openllm
|
||||
import openllm_core
|
||||
import openllm, openllm_core
|
||||
from openllm_core._typing_compat import LiteralString, M, T
|
||||
|
||||
ConfigKeysView = _odict_keys[str, type[openllm_core.LLMConfig]]
|
||||
@@ -30,20 +24,20 @@ else:
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
sorted(
|
||||
[
|
||||
('chatglm', 'ChatGLMConfig'),
|
||||
('dolly_v2', 'DollyV2Config'),
|
||||
('falcon', 'FalconConfig'),
|
||||
('flan_t5', 'FlanT5Config'),
|
||||
('baichuan', 'BaichuanConfig'),
|
||||
('chatglm', 'ChatGLMConfig'), #
|
||||
('falcon', 'FalconConfig'),
|
||||
('gpt_neox', 'GPTNeoXConfig'),
|
||||
('dolly_v2', 'DollyV2Config'),
|
||||
('stablelm', 'StableLMConfig'), #
|
||||
('llama', 'LlamaConfig'),
|
||||
('mpt', 'MPTConfig'),
|
||||
('opt', 'OPTConfig'),
|
||||
('phi', 'PhiConfig'),
|
||||
('stablelm', 'StableLMConfig'),
|
||||
('starcoder', 'StarCoderConfig'),
|
||||
('mistral', 'MistralConfig'),
|
||||
('yi', 'YiConfig'),
|
||||
('baichuan', 'BaichuanConfig'),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -114,56 +108,52 @@ CONFIG_FILE_NAME = 'config.json'
|
||||
|
||||
class AutoConfig:
|
||||
def __init__(self, *_: t.Any, **__: t.Any):
|
||||
raise EnvironmentError(
|
||||
'Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.'
|
||||
)
|
||||
raise EnvironmentError('Cannot instantiate AutoConfig directly. Use `.for_model(model_name)`')
|
||||
|
||||
# fmt: off
|
||||
# update-config-stubs.py: auto stubs start
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['baichuan'],**attrs:t.Any)->openllm_core.config.BaichuanConfig:...
|
||||
def for_model(cls, model_name: t.Literal['baichuan'], **attrs: t.Any) -> openllm_core.config.BaichuanConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['chatglm'],**attrs:t.Any)->openllm_core.config.ChatGLMConfig:...
|
||||
def for_model(cls, model_name: t.Literal['chatglm'], **attrs: t.Any) -> openllm_core.config.ChatGLMConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['dolly_v2'],**attrs:t.Any)->openllm_core.config.DollyV2Config:...
|
||||
def for_model(cls, model_name: t.Literal['dolly_v2'], **attrs: t.Any) -> openllm_core.config.DollyV2Config: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['falcon'],**attrs:t.Any)->openllm_core.config.FalconConfig:...
|
||||
def for_model(cls, model_name: t.Literal['falcon'], **attrs: t.Any) -> openllm_core.config.FalconConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['flan_t5'],**attrs:t.Any)->openllm_core.config.FlanT5Config:...
|
||||
def for_model(cls, model_name: t.Literal['flan_t5'], **attrs: t.Any) -> openllm_core.config.FlanT5Config: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['gpt_neox'],**attrs:t.Any)->openllm_core.config.GPTNeoXConfig:...
|
||||
def for_model(cls, model_name: t.Literal['gpt_neox'], **attrs: t.Any) -> openllm_core.config.GPTNeoXConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['llama'],**attrs:t.Any)->openllm_core.config.LlamaConfig:...
|
||||
def for_model(cls, model_name: t.Literal['llama'], **attrs: t.Any) -> openllm_core.config.LlamaConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['mistral'],**attrs:t.Any)->openllm_core.config.MistralConfig:...
|
||||
def for_model(cls, model_name: t.Literal['mistral'], **attrs: t.Any) -> openllm_core.config.MistralConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['mpt'],**attrs:t.Any)->openllm_core.config.MPTConfig:...
|
||||
def for_model(cls, model_name: t.Literal['mpt'], **attrs: t.Any) -> openllm_core.config.MPTConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['opt'],**attrs:t.Any)->openllm_core.config.OPTConfig:...
|
||||
def for_model(cls, model_name: t.Literal['opt'], **attrs: t.Any) -> openllm_core.config.OPTConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['phi'],**attrs:t.Any)->openllm_core.config.PhiConfig:...
|
||||
def for_model(cls, model_name: t.Literal['phi'], **attrs: t.Any) -> openllm_core.config.PhiConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['stablelm'],**attrs:t.Any)->openllm_core.config.StableLMConfig:...
|
||||
def for_model(cls, model_name: t.Literal['stablelm'], **attrs: t.Any) -> openllm_core.config.StableLMConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['starcoder'],**attrs:t.Any)->openllm_core.config.StarCoderConfig:...
|
||||
def for_model(cls, model_name: t.Literal['starcoder'], **attrs: t.Any) -> openllm_core.config.StarCoderConfig: ...
|
||||
@t.overload
|
||||
@classmethod
|
||||
def for_model(cls,model_name:t.Literal['yi'],**attrs:t.Any)->openllm_core.config.YiConfig:...
|
||||
def for_model(cls, model_name: t.Literal['yi'], **attrs: t.Any) -> openllm_core.config.YiConfig: ...
|
||||
# update-config-stubs.py: auto stubs stop
|
||||
# fmt: on
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm_core.LLMConfig:
|
||||
@@ -196,9 +186,7 @@ class AutoConfig:
|
||||
@classmethod
|
||||
def infer_class_from_llm(cls, llm: openllm.LLM[M, T]) -> type[openllm_core.LLMConfig]:
|
||||
if not is_bentoml_available():
|
||||
raise MissingDependencyError(
|
||||
"'infer_class_from_llm' requires 'bentoml' to be available. Make sure to install it with 'pip install bentoml'"
|
||||
)
|
||||
raise MissingDependencyError("Requires 'bentoml' to be available. Do 'pip install bentoml'")
|
||||
if llm._local:
|
||||
config_file = os.path.join(llm.model_id, CONFIG_FILE_NAME)
|
||||
else:
|
||||
@@ -207,7 +195,7 @@ class AutoConfig:
|
||||
except OpenLLMException as err:
|
||||
if not is_transformers_available():
|
||||
raise MissingDependencyError(
|
||||
"'infer_class_from_llm' requires 'transformers' to be available. Make sure to install it with 'pip install transformers'"
|
||||
"Requires 'transformers' to be available. Do 'pip install transformers'"
|
||||
) from err
|
||||
from transformers.utils import cached_file
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class MistralConfig(openllm_core.LLMConfig):
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format(
|
||||
return '''{start_key}{start_inst} {system_message} {instruction} {end_inst}\n'''.format(
|
||||
start_inst=SINST_KEY,
|
||||
end_inst=EINST_KEY,
|
||||
start_key=BOS_TOKEN,
|
||||
@@ -57,4 +57,4 @@ class MistralConfig(openllm_core.LLMConfig):
|
||||
# NOTE: https://docs.mistral.ai/usage/guardrailing/
|
||||
@property
|
||||
def system_message(self) -> str:
|
||||
return """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity."""
|
||||
return '''Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.'''
|
||||
|
||||
@@ -43,9 +43,9 @@ class StableLMConfig(openllm_core.LLMConfig):
|
||||
|
||||
@property
|
||||
def system_message(self) -> str:
|
||||
return """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
return '''<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
'''
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
'''Base exceptions for OpenLLM. This extends BentoML exceptions.'''
|
||||
|
||||
from __future__ import annotations
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
class OpenLLMException(Exception):
|
||||
"""Base class for all OpenLLM exceptions. This shares similar interface with BentoMLException."""
|
||||
'''Base class for all OpenLLM exceptions. This shares similar interface with BentoMLException.'''
|
||||
|
||||
error_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
@@ -15,28 +15,28 @@ class OpenLLMException(Exception):
|
||||
|
||||
|
||||
class GpuNotAvailableError(OpenLLMException):
|
||||
"""Raised when there is no GPU available in given system."""
|
||||
'''Raised when there is no GPU available in given system.'''
|
||||
|
||||
|
||||
class ValidationError(OpenLLMException):
|
||||
"""Raised when a validation fails."""
|
||||
'''Raised when a validation fails.'''
|
||||
|
||||
|
||||
class ForbiddenAttributeError(OpenLLMException):
|
||||
"""Raised when using an _internal field."""
|
||||
'''Raised when using an _internal field.'''
|
||||
|
||||
|
||||
class MissingAnnotationAttributeError(OpenLLMException):
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
'''Raised when a field under openllm.LLMConfig is missing annotations.'''
|
||||
|
||||
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
'''Raised when a dependency is missing.'''
|
||||
|
||||
|
||||
class Error(BaseException):
|
||||
"""To be used instead of naked raise."""
|
||||
'''To be used instead of naked raise.'''
|
||||
|
||||
|
||||
class FineTuneStrategyNotSupportedError(OpenLLMException):
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
'''Raised when a fine-tune strategy is not supported for given LLM.'''
|
||||
|
||||
@@ -15,129 +15,244 @@ WARNING_ENV_VAR = 'OPENLLM_DISABLE_WARNING'
|
||||
DEV_DEBUG_VAR = 'DEBUG'
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# fmt: off
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _WithArgsTypes()->tuple[type[t.Any],...]:
|
||||
try:from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError:_TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
def _WithArgsTypes() -> tuple[type[t.Any], ...]:
|
||||
try:
|
||||
from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError:
|
||||
_TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
# _GenericAlias is the actual GenericAlias implementation
|
||||
return (_TypingGenericAlias,) if sys.version_info<(3,10) else (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore
|
||||
def lenient_issubclass(cls,class_or_tuple):
|
||||
try:return isinstance(cls,type) and issubclass(cls,class_or_tuple)
|
||||
return (
|
||||
(_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def lenient_issubclass(cls, class_or_tuple):
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
if isinstance(cls,_WithArgsTypes()):return False
|
||||
if isinstance(cls, _WithArgsTypes()):
|
||||
return False
|
||||
raise
|
||||
|
||||
|
||||
def resolve_user_filepath(filepath, ctx):
|
||||
_path=os.path.expanduser(os.path.expandvars(filepath))
|
||||
if os.path.exists(_path):return os.path.realpath(_path)
|
||||
_path = os.path.expanduser(os.path.expandvars(filepath))
|
||||
if os.path.exists(_path):
|
||||
return os.path.realpath(_path)
|
||||
# Try finding file in ctx if provided
|
||||
if ctx:
|
||||
_path=os.path.expanduser(os.path.join(ctx, filepath))
|
||||
if os.path.exists(_path):return os.path.realpath(_path)
|
||||
_path = os.path.expanduser(os.path.join(ctx, filepath))
|
||||
if os.path.exists(_path):
|
||||
return os.path.realpath(_path)
|
||||
raise FileNotFoundError(f'file {filepath} not found')
|
||||
|
||||
|
||||
# this is the supress version of resolve_user_filepath
|
||||
def resolve_filepath(path,ctx=None):
|
||||
try:return resolve_user_filepath(path, ctx)
|
||||
except FileNotFoundError:return path
|
||||
def check_bool_env(env,default=True):
|
||||
v=os.getenv(env,default=str(default)).upper()
|
||||
if v.isdigit():return bool(int(v)) # special check for digits
|
||||
def resolve_filepath(path, ctx=None):
|
||||
try:
|
||||
return resolve_user_filepath(path, ctx)
|
||||
except FileNotFoundError:
|
||||
return path
|
||||
|
||||
|
||||
def check_bool_env(env, default=True):
|
||||
v = os.getenv(env, default=str(default)).upper()
|
||||
if v.isdigit():
|
||||
return bool(int(v)) # special check for digits
|
||||
return v in ENV_VARS_TRUE_VALUES
|
||||
def calc_dir_size(path):return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file())
|
||||
|
||||
|
||||
def calc_dir_size(path):
|
||||
return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file())
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def generate_hash_from_file(f,algorithm='sha1'):return str(getattr(hashlib,algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
|
||||
def getenv(env,default=None,var=None):
|
||||
env_key={env.upper(),f'OPENLLM_{env.upper()}'}
|
||||
if var is not None:env_key=set(var)|env_key
|
||||
def callback(k:str)->t.Any:
|
||||
def generate_hash_from_file(f, algorithm='sha1'):
|
||||
return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
|
||||
|
||||
|
||||
def getenv(env, default=None, var=None):
|
||||
env_key = {env.upper(), f'OPENLLM_{env.upper()}'}
|
||||
if var is not None:
|
||||
env_key = set(var) | env_key
|
||||
|
||||
def callback(k: str) -> t.Any:
|
||||
_var = os.getenv(k)
|
||||
if _var and k.startswith('OPENLLM_'):logger.warning("Using '%s' environment is deprecated, use '%s' instead.",k.upper(),k[8:].upper())
|
||||
if _var and k.startswith('OPENLLM_'):
|
||||
logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper())
|
||||
return _var
|
||||
return first_not_none(*(callback(k) for k in env_key),default=default)
|
||||
def field_env_key(key,suffix=None):return '_'.join(filter(None,map(str.upper,['OPENLLM',suffix.strip('_') if suffix else '',key])))
|
||||
def get_debug_mode():return check_bool_env(DEBUG_ENV_VAR,False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG
|
||||
|
||||
return first_not_none(*(callback(k) for k in env_key), default=default)
|
||||
|
||||
|
||||
def field_env_key(key, suffix=None):
|
||||
return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key])))
|
||||
|
||||
|
||||
def get_debug_mode():
|
||||
return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG
|
||||
|
||||
|
||||
def get_quiet_mode():
|
||||
if QUIET_ENV_VAR in os.environ:return check_bool_env(QUIET_ENV_VAR, False)
|
||||
if DEBUG:return False
|
||||
if QUIET_ENV_VAR in os.environ:
|
||||
return check_bool_env(QUIET_ENV_VAR, False)
|
||||
if DEBUG:
|
||||
return False
|
||||
return False
|
||||
def get_disable_warnings():return check_bool_env(WARNING_ENV_VAR, False)
|
||||
|
||||
|
||||
def get_disable_warnings():
|
||||
return check_bool_env(WARNING_ENV_VAR, False)
|
||||
|
||||
|
||||
def set_disable_warnings(disable=True):
|
||||
if disable:os.environ[WARNING_ENV_VAR]=str(disable)
|
||||
def set_debug_mode(enabled,level=1):
|
||||
if enabled:os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ.update({DEBUG_ENV_VAR:str(enabled),QUIET_ENV_VAR:str(not enabled),_GRPC_DEBUG_ENV_VAR:'DEBUG' if enabled else 'ERROR','CT2_VERBOSE':'3'})
|
||||
if disable:
|
||||
os.environ[WARNING_ENV_VAR] = str(disable)
|
||||
|
||||
|
||||
def set_debug_mode(enabled, level=1):
|
||||
if enabled:
|
||||
os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ.update(
|
||||
{
|
||||
DEBUG_ENV_VAR: str(enabled),
|
||||
QUIET_ENV_VAR: str(not enabled),
|
||||
_GRPC_DEBUG_ENV_VAR: 'DEBUG' if enabled else 'ERROR',
|
||||
'CT2_VERBOSE': '3',
|
||||
}
|
||||
)
|
||||
set_disable_warnings(not enabled)
|
||||
|
||||
|
||||
def set_quiet_mode(enabled):
|
||||
os.environ.update({QUIET_ENV_VAR:str(enabled),DEBUG_ENV_VAR:str(not enabled),_GRPC_DEBUG_ENV_VAR:'NONE','CT2_VERBOSE':'-1'})
|
||||
os.environ.update(
|
||||
{QUIET_ENV_VAR: str(enabled), DEBUG_ENV_VAR: str(not enabled), _GRPC_DEBUG_ENV_VAR: 'NONE', 'CT2_VERBOSE': '-1'}
|
||||
)
|
||||
set_disable_warnings(enabled)
|
||||
def gen_random_uuid(prefix:str|None=None)->str:return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
|
||||
# NOTE: `compose` any number of unary functions into a single unary function
|
||||
|
||||
|
||||
def gen_random_uuid(prefix: str | None = None) -> str:
|
||||
return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
|
||||
|
||||
|
||||
# NOTE: `compose` any number of unary functions into a single unary function
|
||||
# compose(f, g, h)(x) == f(g(h(x))); compose(f, g, h)(x, y, z) == f(g(h(x, y, z)))
|
||||
def compose(*funcs):return functools.reduce(lambda f1,f2:lambda *args,**kwargs:f1(f2(*args,**kwargs)),funcs)
|
||||
# NOTE: `apply` a transform function that is invoked on results returned from the decorated function
|
||||
def compose(*funcs):
|
||||
return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs)
|
||||
|
||||
|
||||
# NOTE: `apply` a transform function that is invoked on results returned from the decorated function
|
||||
# apply(reversed)(func)(*args, **kwargs) == reversed(func(*args, **kwargs))
|
||||
def apply(transform):return lambda func:functools.wraps(func)(compose(transform,func))
|
||||
def validate_is_path(maybe_path):return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
def first_not_none(*args,default=None):return next((arg for arg in args if arg is not None),default)
|
||||
def apply(transform):
|
||||
return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
|
||||
|
||||
def validate_is_path(maybe_path):
|
||||
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
|
||||
def first_not_none(*args, default=None):
|
||||
return next((arg for arg in args if arg is not None), default)
|
||||
|
||||
|
||||
def generate_context(framework_name):
|
||||
from bentoml._internal.models.model import ModelContext
|
||||
framework_versions={'transformers':pkg.get_pkg_version('transformers'),'safetensors':pkg.get_pkg_version('safetensors'),'optimum':pkg.get_pkg_version('optimum'),'accelerate':pkg.get_pkg_version('accelerate')}
|
||||
if iutils.is_torch_available():framework_versions['torch']=pkg.get_pkg_version('torch')
|
||||
if iutils.is_ctranslate_available():framework_versions['ctranslate2']=pkg.get_pkg_version('ctranslate2')
|
||||
if iutils.is_vllm_available():framework_versions['vllm']=pkg.get_pkg_version('vllm')
|
||||
if iutils.is_autoawq_available():framework_versions['autoawq']=pkg.get_pkg_version('autoawq')
|
||||
if iutils.is_autogptq_available():framework_versions['autogptq']=pkg.get_pkg_version('auto_gptq')
|
||||
if iutils.is_bentoml_available():framework_versions['bentoml']=pkg.get_pkg_version('bentoml')
|
||||
if iutils.is_triton_available():framework_versions['triton']=pkg.get_pkg_version('triton')
|
||||
return ModelContext(framework_name=framework_name,framework_versions=framework_versions)
|
||||
|
||||
framework_versions = {
|
||||
'transformers': pkg.get_pkg_version('transformers'),
|
||||
'safetensors': pkg.get_pkg_version('safetensors'),
|
||||
'optimum': pkg.get_pkg_version('optimum'),
|
||||
'accelerate': pkg.get_pkg_version('accelerate'),
|
||||
}
|
||||
if iutils.is_torch_available():
|
||||
framework_versions['torch'] = pkg.get_pkg_version('torch')
|
||||
if iutils.is_ctranslate_available():
|
||||
framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2')
|
||||
if iutils.is_vllm_available():
|
||||
framework_versions['vllm'] = pkg.get_pkg_version('vllm')
|
||||
if iutils.is_autoawq_available():
|
||||
framework_versions['autoawq'] = pkg.get_pkg_version('autoawq')
|
||||
if iutils.is_autogptq_available():
|
||||
framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq')
|
||||
if iutils.is_bentoml_available():
|
||||
framework_versions['bentoml'] = pkg.get_pkg_version('bentoml')
|
||||
if iutils.is_triton_available():
|
||||
framework_versions['triton'] = pkg.get_pkg_version('triton')
|
||||
return ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def in_notebook():
|
||||
try:from IPython.core.getipython import get_ipython;return 'IPKernelApp' in get_ipython().config
|
||||
except Exception:return False
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
|
||||
return 'IPKernelApp' in get_ipython().config
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_TOKENIZER_PREFIX = '_tokenizer_'
|
||||
|
||||
|
||||
def flatten_attrs(**attrs):
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]:v for k,v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
for k in tuple(attrs.keys()):
|
||||
if k.startswith(_TOKENIZER_PREFIX):del attrs[k]
|
||||
return attrs,tokenizer_attrs
|
||||
if k.startswith(_TOKENIZER_PREFIX):
|
||||
del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
|
||||
# Special debug flag controled via DEBUG
|
||||
DEBUG=sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False))
|
||||
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False))
|
||||
# Whether to show the codenge for debug purposes
|
||||
SHOW_CODEGEN=DEBUG and (os.environ.get(DEV_DEBUG_VAR,str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR,str(0)))>3)
|
||||
SHOW_CODEGEN = DEBUG and (
|
||||
os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3
|
||||
)
|
||||
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
|
||||
MYPY=False
|
||||
MYPY = False
|
||||
|
||||
|
||||
class ExceptionFilter(logging.Filter):
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
if exclude_exceptions is None:exclude_exceptions=[]
|
||||
if exclude_exceptions is None:
|
||||
exclude_exceptions = []
|
||||
try:
|
||||
from circus.exc import ConflictError
|
||||
if ConflictError not in exclude_exceptions:exclude_exceptions.append(ConflictError)
|
||||
|
||||
if ConflictError not in exclude_exceptions:
|
||||
exclude_exceptions.append(ConflictError)
|
||||
except ImportError:
|
||||
pass
|
||||
super(ExceptionFilter, self).__init__(**kwargs)
|
||||
self.EXCLUDE_EXCEPTIONS=exclude_exceptions
|
||||
def filter(self,record:logging.LogRecord)->bool:
|
||||
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
etype,_,_=record.exc_info
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc):return False
|
||||
if issubclass(etype, exc):
|
||||
return False
|
||||
return True
|
||||
# Used to filter out INFO log
|
||||
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
def filter(self,record:logging.LogRecord)->bool:return logging.INFO<=record.levelno<logging.WARNING
|
||||
class WarningFilter(logging.Filter): # FIXME: Why does this not work?
|
||||
def filter(self,record:logging.LogRecord)->bool:
|
||||
if get_disable_warnings():return record.levelno>=logging.ERROR
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return logging.INFO <= record.levelno < logging.WARNING
|
||||
|
||||
|
||||
class WarningFilter(logging.Filter): # FIXME: Why does this not work?
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if get_disable_warnings():
|
||||
return record.levelno >= logging.ERROR
|
||||
return True
|
||||
# fmt: on
|
||||
|
||||
|
||||
_LOGGING_CONFIG: dict[str, t.Any] = {
|
||||
_LOGGING_CONFIG = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': True,
|
||||
'filters': {
|
||||
@@ -184,7 +299,7 @@ def configure_logging():
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras: dict[str, t.Any] = {
|
||||
_extras = {
|
||||
**{
|
||||
k: v
|
||||
for k, v in locals().items()
|
||||
|
||||
@@ -86,14 +86,14 @@ def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> An
|
||||
|
||||
|
||||
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
|
||||
"""Create a tuple subclass to hold class attributes.
|
||||
'''Create a tuple subclass to hold class attributes.
|
||||
|
||||
The subclass is a bare tuple with properties for names.
|
||||
|
||||
class MyClassAttributes(tuple):
|
||||
__slots__ = ()
|
||||
x = property(itemgetter(0))
|
||||
"""
|
||||
'''
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
attr_class_name = f'{cls_name}Attributes'
|
||||
|
||||
@@ -210,14 +210,14 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
|
||||
|
||||
|
||||
def is_typing(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a module-like type.
|
||||
'''Checks whether the current type is a module-like type.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
bool: true if the type is itself a type
|
||||
"""
|
||||
'''
|
||||
raw = t.get_origin(field_type)
|
||||
if raw is None:
|
||||
return False
|
||||
@@ -227,7 +227,7 @@ def is_typing(field_type: type) -> bool:
|
||||
|
||||
|
||||
def is_literal(field_type: type) -> bool:
|
||||
"""Checks whether the given field type is a Literal type or not.
|
||||
'''Checks whether the given field type is a Literal type or not.
|
||||
|
||||
Literals are weird: isinstance and subclass do not work, so you compare
|
||||
the origin with the Literal declaration itself.
|
||||
@@ -237,7 +237,7 @@ def is_literal(field_type: type) -> bool:
|
||||
|
||||
Returns:
|
||||
bool: true if Literal type, false otherwise
|
||||
"""
|
||||
'''
|
||||
origin = t.get_origin(field_type)
|
||||
return origin is not None and origin is t.Literal
|
||||
|
||||
@@ -272,12 +272,12 @@ class EnumChoice(click.Choice):
|
||||
name = 'enum'
|
||||
|
||||
def __init__(self, enum: Enum, case_sensitive: bool = False):
|
||||
"""Enum type support for click that extends ``click.Choice``.
|
||||
'''Enum type support for click that extends ``click.Choice``.
|
||||
|
||||
Args:
|
||||
enum: Given enum
|
||||
case_sensitive: Whether this choice should be case case_sensitive.
|
||||
"""
|
||||
'''
|
||||
self.mapping = enum
|
||||
self.internal_type = type(enum)
|
||||
choices: list[t.Any] = [e.name for e in enum.__class__]
|
||||
@@ -296,7 +296,7 @@ class LiteralChoice(EnumChoice):
|
||||
name = 'literal'
|
||||
|
||||
def __init__(self, value: t.Any, case_sensitive: bool = False):
|
||||
"""Literal support for click."""
|
||||
'''Literal support for click.'''
|
||||
# expect every literal value to belong to the same primitive type
|
||||
values = list(value.__args__)
|
||||
item_type = type(values[0])
|
||||
@@ -334,14 +334,14 @@ def allows_multiple(field_type: type[t.Any]) -> bool:
|
||||
|
||||
|
||||
def is_mapping(field_type: type) -> bool:
|
||||
"""Checks whether this field represents a dictionary or JSON object.
|
||||
'''Checks whether this field represents a dictionary or JSON object.
|
||||
|
||||
Args:
|
||||
field_type (type): pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true when the field is a dict-like object, false otherwise.
|
||||
"""
|
||||
'''
|
||||
# Early out for standard containers.
|
||||
from . import lenient_issubclass
|
||||
|
||||
@@ -379,14 +379,14 @@ def is_container(field_type: type) -> bool:
|
||||
|
||||
|
||||
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]:
|
||||
"""Parses the arguments inside a container type (lists, tuples and so on).
|
||||
'''Parses the arguments inside a container type (lists, tuples and so on).
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
ParamType | tuple[ParamType]: single click-compatible type or a tuple
|
||||
"""
|
||||
'''
|
||||
if not is_container(field_type):
|
||||
raise ValueError('Field type is not a container type.')
|
||||
args = t.get_args(field_type)
|
||||
@@ -466,7 +466,7 @@ class CudaValueType(ParamType):
|
||||
return var
|
||||
|
||||
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
|
||||
'''Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
|
||||
|
||||
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
|
||||
|
||||
@@ -474,7 +474,7 @@ class CudaValueType(ParamType):
|
||||
ctx: Invocation context for this command.
|
||||
param: The parameter that is requesting completion.
|
||||
incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
'''
|
||||
from openllm.utils import available_devices
|
||||
|
||||
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
|
||||
|
||||
@@ -1,53 +1,110 @@
|
||||
# fmt: off
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import os
|
||||
import importlib, importlib.metadata, importlib.util, os
|
||||
|
||||
OPTIONAL_DEPENDENCIES={'vllm','fine-tune','ggml','ctranslate','agents','openai','playground','gptq','grpc','awq'}
|
||||
ENV_VARS_TRUE_VALUES={'1','ON','YES','TRUE'}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES=ENV_VARS_TRUE_VALUES.union({'AUTO'})
|
||||
USE_VLLM=os.getenv('USE_VLLM','AUTO').upper()
|
||||
def _is_package_available(package:str)->bool:
|
||||
_package_available=importlib.util.find_spec(package) is not None
|
||||
OPTIONAL_DEPENDENCIES = {
|
||||
'vllm',
|
||||
'fine-tune',
|
||||
'ggml',
|
||||
'ctranslate',
|
||||
'agents',
|
||||
'openai',
|
||||
'playground',
|
||||
'gptq',
|
||||
'grpc',
|
||||
'awq',
|
||||
}
|
||||
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
|
||||
USE_VLLM = os.getenv('USE_VLLM', 'AUTO').upper()
|
||||
|
||||
|
||||
def _is_package_available(package: str) -> bool:
|
||||
_package_available = importlib.util.find_spec(package) is not None
|
||||
if _package_available:
|
||||
try:importlib.metadata.version(package)
|
||||
except importlib.metadata.PackageNotFoundError:_package_available=False
|
||||
try:
|
||||
importlib.metadata.version(package)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_package_available = False
|
||||
return _package_available
|
||||
_ctranslate_available=importlib.util.find_spec('ctranslate2') is not None
|
||||
_vllm_available=importlib.util.find_spec('vllm') is not None
|
||||
_grpc_available=importlib.util.find_spec('grpc') is not None
|
||||
_autoawq_available=importlib.util.find_spec('awq') is not None
|
||||
_triton_available=importlib.util.find_spec('triton') is not None
|
||||
_torch_available=_is_package_available('torch')
|
||||
_transformers_available=_is_package_available('transformers')
|
||||
_bentoml_available=_is_package_available('bentoml')
|
||||
_peft_available=_is_package_available('peft')
|
||||
_bitsandbytes_available=_is_package_available('bitsandbytes')
|
||||
_jupyter_available=_is_package_available('jupyter')
|
||||
_jupytext_available=_is_package_available('jupytext')
|
||||
_notebook_available=_is_package_available('notebook')
|
||||
_autogptq_available=_is_package_available('auto_gptq')
|
||||
def is_triton_available()->bool:return _triton_available
|
||||
def is_ctranslate_available()->bool:return _ctranslate_available
|
||||
def is_bentoml_available()->bool:return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml
|
||||
def is_transformers_available()->bool:return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers
|
||||
def is_grpc_available()->bool:return _grpc_available
|
||||
def is_jupyter_available()->bool:return _jupyter_available
|
||||
def is_jupytext_available()->bool:return _jupytext_available
|
||||
def is_notebook_available()->bool:return _notebook_available
|
||||
def is_peft_available()->bool:return _peft_available
|
||||
def is_bitsandbytes_available()->bool:return _bitsandbytes_available
|
||||
def is_autogptq_available()->bool:return _autogptq_available
|
||||
def is_torch_available()->bool:return _torch_available
|
||||
def is_autoawq_available()->bool:
|
||||
|
||||
|
||||
_ctranslate_available = importlib.util.find_spec('ctranslate2') is not None
|
||||
_vllm_available = importlib.util.find_spec('vllm') is not None
|
||||
_grpc_available = importlib.util.find_spec('grpc') is not None
|
||||
_autoawq_available = importlib.util.find_spec('awq') is not None
|
||||
_triton_available = importlib.util.find_spec('triton') is not None
|
||||
_torch_available = _is_package_available('torch')
|
||||
_transformers_available = _is_package_available('transformers')
|
||||
_bentoml_available = _is_package_available('bentoml')
|
||||
_peft_available = _is_package_available('peft')
|
||||
_bitsandbytes_available = _is_package_available('bitsandbytes')
|
||||
_jupyter_available = _is_package_available('jupyter')
|
||||
_jupytext_available = _is_package_available('jupytext')
|
||||
_notebook_available = _is_package_available('notebook')
|
||||
_autogptq_available = _is_package_available('auto_gptq')
|
||||
|
||||
|
||||
def is_triton_available() -> bool:
|
||||
return _triton_available
|
||||
|
||||
|
||||
def is_ctranslate_available() -> bool:
|
||||
return _ctranslate_available
|
||||
|
||||
|
||||
def is_bentoml_available() -> bool:
|
||||
return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml
|
||||
|
||||
|
||||
def is_transformers_available() -> bool:
|
||||
return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers
|
||||
|
||||
|
||||
def is_grpc_available() -> bool:
|
||||
return _grpc_available
|
||||
|
||||
|
||||
def is_jupyter_available() -> bool:
|
||||
return _jupyter_available
|
||||
|
||||
|
||||
def is_jupytext_available() -> bool:
|
||||
return _jupytext_available
|
||||
|
||||
|
||||
def is_notebook_available() -> bool:
|
||||
return _notebook_available
|
||||
|
||||
|
||||
def is_peft_available() -> bool:
|
||||
return _peft_available
|
||||
|
||||
|
||||
def is_bitsandbytes_available() -> bool:
|
||||
return _bitsandbytes_available
|
||||
|
||||
|
||||
def is_autogptq_available() -> bool:
|
||||
return _autogptq_available
|
||||
|
||||
|
||||
def is_torch_available() -> bool:
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_autoawq_available() -> bool:
|
||||
global _autoawq_available
|
||||
try:importlib.metadata.version('autoawq')
|
||||
except importlib.metadata.PackageNotFoundError:_autoawq_available=False
|
||||
try:
|
||||
importlib.metadata.version('autoawq')
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_autoawq_available = False
|
||||
return _autoawq_available
|
||||
def is_vllm_available()->bool:
|
||||
|
||||
|
||||
def is_vllm_available() -> bool:
|
||||
global _vllm_available
|
||||
if USE_VLLM in ENV_VARS_TRUE_AND_AUTO_VALUES or _vllm_available:
|
||||
try:importlib.metadata.version('vllm')
|
||||
except importlib.metadata.PackageNotFoundError:_vllm_available=False
|
||||
try:
|
||||
importlib.metadata.version('vllm')
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_vllm_available = False
|
||||
return _vllm_available
|
||||
|
||||
@@ -23,11 +23,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LazyLoader(types.ModuleType):
|
||||
"""
|
||||
'''
|
||||
LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py with a addition of "module caching". This will throw an exception if module cannot be imported.
|
||||
|
||||
Lazily import a module, mainly to avoid pulling in large dependencies. `contrib`, and `ffmpeg` are examples of modules that are large and not always needed, and this allows them to only be loaded when they are used.
|
||||
"""
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -180,12 +180,12 @@ class LazyModule(types.ModuleType):
|
||||
return result + [i for i in self.__all__ if i not in result]
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
"""Equivocal __getattr__ implementation.
|
||||
'''Equivocal __getattr__ implementation.
|
||||
|
||||
It checks from _objects > _modules and does it recursively.
|
||||
|
||||
It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
|
||||
"""
|
||||
'''
|
||||
if name in _reserved_namespace:
|
||||
raise openllm_core.exceptions.ForbiddenAttributeError(
|
||||
f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified."
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""OpenLLM.
|
||||
'''OpenLLM.
|
||||
===========
|
||||
|
||||
An open platform for operating large language models in production.
|
||||
@@ -8,13 +8,11 @@ Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
* Option to bring your own fine-tuned LLMs
|
||||
* Online Serving with HTTP, gRPC, SSE or custom API
|
||||
* Native integration with BentoML, LangChain, OpenAI compatible endpoints, LlamaIndex for custom LLM apps
|
||||
"""
|
||||
'''
|
||||
|
||||
# fmt: off
|
||||
# update-config-stubs.py: import stubs start
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING,CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES,AutoConfig as AutoConfig,BaichuanConfig as BaichuanConfig,ChatGLMConfig as ChatGLMConfig,DollyV2Config as DollyV2Config,FalconConfig as FalconConfig,FlanT5Config as FlanT5Config,GPTNeoXConfig as GPTNeoXConfig,LlamaConfig as LlamaConfig,MistralConfig as MistralConfig,MPTConfig as MPTConfig,OPTConfig as OPTConfig,PhiConfig as PhiConfig,StableLMConfig as StableLMConfig,StarCoderConfig as StarCoderConfig,YiConfig as YiConfig
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
|
||||
# update-config-stubs.py: import stubs stop
|
||||
# fmt: on
|
||||
|
||||
from openllm_cli._sdk import (
|
||||
build as build,
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
# fmt: off
|
||||
if __name__ == '__main__':from openllm_cli.entrypoint import cli;cli()
|
||||
if __name__ == '__main__':
|
||||
from openllm_cli.entrypoint import cli
|
||||
|
||||
cli()
|
||||
|
||||
@@ -59,7 +59,7 @@ def Runner(
|
||||
"'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation."
|
||||
)
|
||||
model_id = attrs.get('model_id', os.getenv('OPENLLM_MODEL_ID', llm_config['default_id']))
|
||||
_RUNNER_MSG = f"""\
|
||||
_RUNNER_MSG = f'''\
|
||||
Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax:
|
||||
|
||||
```python
|
||||
@@ -71,7 +71,7 @@ def Runner(
|
||||
async def chat(input: str) -> str:
|
||||
async for it in llm.generate_iterator(input): print(it)
|
||||
```
|
||||
"""
|
||||
'''
|
||||
warnings.warn(_RUNNER_MSG, DeprecationWarning, stacklevel=2)
|
||||
attrs.update(
|
||||
{
|
||||
|
||||
@@ -33,7 +33,7 @@ def is_sentence_complete(output):
|
||||
|
||||
|
||||
def is_partial_stop(output, stop_str):
|
||||
"""Check whether the output contains a partial stop str."""
|
||||
'''Check whether the output contains a partial stop str.'''
|
||||
for i in range(min(len(output), len(stop_str))):
|
||||
if stop_str.startswith(output[-i:]):
|
||||
return True
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import functools, logging, os, warnings
|
||||
import typing as t
|
||||
import attr, inflection, orjson
|
||||
import bentoml, openllm
|
||||
import functools, logging, os, warnings, typing as t
|
||||
import attr, inflection, orjson, bentoml, openllm
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import (
|
||||
AdapterMap,
|
||||
@@ -20,9 +18,9 @@ from openllm_core._typing_compat import (
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
ReprMixin,
|
||||
apply,
|
||||
check_bool_env,
|
||||
codegen,
|
||||
first_not_none,
|
||||
flatten_attrs,
|
||||
@@ -142,31 +140,33 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
|
||||
# NOTE: If you are here to see how generate_iterator and generate works, see above.
|
||||
# The below are mainly for internal implementation that you don't have to worry about.
|
||||
# fmt: off
|
||||
|
||||
_model_id:str
|
||||
_revision:t.Optional[str]
|
||||
_quantization_config:t.Optional[t.Union[transformers.BitsAndBytesConfig,transformers.GPTQConfig,transformers.AwqConfig]]
|
||||
_model_id: str
|
||||
_revision: t.Optional[str]
|
||||
_quantization_config: t.Optional[
|
||||
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
|
||||
]
|
||||
_quantise: t.Optional[LiteralQuantise]
|
||||
_model_decls:TupleAny
|
||||
__model_attrs:DictStrAny
|
||||
__tokenizer_attrs:DictStrAny
|
||||
_tag:bentoml.Tag
|
||||
_adapter_map:t.Optional[AdapterMap]
|
||||
_serialisation:LiteralSerialisation
|
||||
_local:bool
|
||||
_max_model_len:t.Optional[int]
|
||||
_model_decls: TupleAny
|
||||
__model_attrs: DictStrAny
|
||||
__tokenizer_attrs: DictStrAny
|
||||
_tag: bentoml.Tag
|
||||
_adapter_map: t.Optional[AdapterMap]
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_max_model_len: t.Optional[int]
|
||||
|
||||
__llm_dtype__: t.Union[LiteralDtype,t.Literal['auto', 'half', 'float']]='auto'
|
||||
__llm_torch_dtype__:'torch.dtype'=None
|
||||
__llm_config__:t.Optional[LLMConfig]=None
|
||||
__llm_backend__:LiteralBackend=None
|
||||
__llm_quantization_config__:t.Optional[t.Union[transformers.BitsAndBytesConfig,transformers.GPTQConfig,transformers.AwqConfig]]=None
|
||||
__llm_runner__:t.Optional[Runner[M, T]]=None
|
||||
__llm_model__:t.Optional[M]=None
|
||||
__llm_tokenizer__:t.Optional[T]=None
|
||||
__llm_adapter_map__:t.Optional[ResolvedAdapterMap]=None
|
||||
__llm_trust_remote_code__:bool=False
|
||||
__llm_dtype__: t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']] = 'auto'
|
||||
__llm_torch_dtype__: 'torch.dtype' = None
|
||||
__llm_config__: t.Optional[LLMConfig] = None
|
||||
__llm_backend__: LiteralBackend = None
|
||||
__llm_quantization_config__: t.Optional[
|
||||
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
|
||||
] = None
|
||||
__llm_runner__: t.Optional[Runner[M, T]] = None
|
||||
__llm_model__: t.Optional[M] = None
|
||||
__llm_tokenizer__: t.Optional[T] = None
|
||||
__llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None
|
||||
__llm_trust_remote_code__: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -188,26 +188,34 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
_eager=True,
|
||||
**attrs,
|
||||
):
|
||||
torch_dtype=attrs.pop('torch_dtype',None) # backward compatible
|
||||
if torch_dtype is not None:warnings.warns('The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.',DeprecationWarning,stacklevel=3);dtype=torch_dtype
|
||||
torch_dtype = attrs.pop('torch_dtype', None) # backward compatible
|
||||
if torch_dtype is not None:
|
||||
warnings.warns(
|
||||
'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
dtype = torch_dtype
|
||||
_local = False
|
||||
if validate_is_path(model_id):model_id,_local=resolve_filepath(model_id),True
|
||||
backend=first_not_none(getenv('backend',default=backend),default=self._cascade_backend())
|
||||
dtype=first_not_none(getenv('dtype',default=dtype,var=['TORCH_DTYPE']),default='auto')
|
||||
quantize=first_not_none(getenv('quantize',default=quantize,var=['QUANITSE']),default=None)
|
||||
attrs.update({'low_cpu_mem_usage':low_cpu_mem_usage})
|
||||
if validate_is_path(model_id):
|
||||
model_id, _local = resolve_filepath(model_id), True
|
||||
backend = first_not_none(getenv('backend', default=backend), default=self._cascade_backend())
|
||||
dtype = first_not_none(getenv('dtype', default=dtype, var=['TORCH_DTYPE']), default='auto')
|
||||
quantize = first_not_none(getenv('quantize', default=quantize, var=['QUANITSE']), default=None)
|
||||
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
model_attrs,tokenizer_attrs=flatten_attrs(**attrs)
|
||||
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
|
||||
if model_tag is None:
|
||||
model_tag,model_version=self._make_tag_components(model_id,model_version,backend=backend)
|
||||
if model_version:model_tag=f'{model_tag}:{model_version}'
|
||||
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
|
||||
if model_version:
|
||||
model_tag = f'{model_tag}:{model_version}'
|
||||
|
||||
self.__attrs_init__(
|
||||
model_id=model_id,
|
||||
revision=model_version,
|
||||
tag=bentoml.Tag.from_taglike(model_tag),
|
||||
quantization_config=quantization_config,
|
||||
quantise=getattr(self._Quantise,backend)(self,quantize),
|
||||
quantise=getattr(self._Quantise, backend)(self, quantize),
|
||||
model_decls=args,
|
||||
adapter_map=convert_peft_config_type(adapter_map) if adapter_map is not None else None,
|
||||
serialisation=serialisation,
|
||||
@@ -220,143 +228,248 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
llm_config__=llm_config,
|
||||
llm_trust_remote_code__=trust_remote_code,
|
||||
)
|
||||
|
||||
if _eager:
|
||||
try:
|
||||
model=bentoml.models.get(self.tag)
|
||||
model = bentoml.models.get(self.tag)
|
||||
except bentoml.exceptions.NotFound:
|
||||
model=openllm.serialisation.import_model(self,trust_remote_code=self.trust_remote_code)
|
||||
model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
|
||||
# resolve the tag
|
||||
self._tag=model.tag
|
||||
if not _eager and embedded:raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
|
||||
if embedded:logger.warning('Models will be loaded into memory. NOT RECOMMENDED in production and SHOULD ONLY used for development.');self.runner.init_local(quiet=True)
|
||||
self._tag = model.tag
|
||||
if not _eager and embedded:
|
||||
raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
|
||||
if embedded:
|
||||
logger.warning(
|
||||
'NOT RECOMMENDED in production and SHOULD ONLY used for development (Loading into current memory).'
|
||||
)
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
class _Quantise:
|
||||
@staticmethod
|
||||
def pt(llm:LLM,quantise=None):return quantise
|
||||
@staticmethod
|
||||
def vllm(llm:LLM,quantise=None):return quantise
|
||||
@staticmethod
|
||||
def ctranslate(llm:LLM,quantise=None):
|
||||
if quantise in {'int4','awq','gptq','squeezellm'}:raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise=='int8':quantise='int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
def pt(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
@apply(lambda val:tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self,model_id:str,model_version:str|None,backend:str)->tuple[str,str|None]:
|
||||
model_id,*maybe_revision=model_id.rsplit(':')
|
||||
if len(maybe_revision)>0:
|
||||
if model_version is not None:logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.",maybe_revision[0],model_version)
|
||||
|
||||
@staticmethod
|
||||
def vllm(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
|
||||
@staticmethod
|
||||
def ctranslate(llm: LLM, quantise=None):
|
||||
if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}:
|
||||
raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise == 'int8':
|
||||
quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
return quantise
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]:
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None:
|
||||
logger.warning(
|
||||
"revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version
|
||||
)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id):model_id,model_version=resolve_filepath(model_id),first_not_none(model_version,default=generate_hash_from_file(model_id))
|
||||
return f'{backend}-{normalise_model_name(model_id)}',model_version
|
||||
if validate_is_path(model_id):
|
||||
model_id, model_version = (
|
||||
resolve_filepath(model_id),
|
||||
first_not_none(model_version, default=generate_hash_from_file(model_id)),
|
||||
)
|
||||
return f'{backend}-{normalise_model_name(model_id)}', model_version
|
||||
|
||||
@functools.cached_property
|
||||
def _has_gpus(self):
|
||||
try:
|
||||
from cuda import cuda
|
||||
err,*_=cuda.cuInit(0)
|
||||
if err!=cuda.CUresult.CUDA_SUCCESS:raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err,num_gpus=cuda.cuDeviceGetCount()
|
||||
if err!=cuda.CUresult.CUDA_SUCCESS:raise RuntimeError('Failed to get CUDA device count.')
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err, num_gpus = cuda.cuDeviceGetCount()
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to get CUDA device count.')
|
||||
return True
|
||||
except (ImportError, RuntimeError):return False
|
||||
except (ImportError, RuntimeError):
|
||||
return False
|
||||
|
||||
@property
|
||||
def _torch_dtype(self):
|
||||
import torch, transformers
|
||||
_map=_torch_dtype_mapping()
|
||||
if not isinstance(self.__llm_torch_dtype__,torch.dtype):
|
||||
try:hf_config=transformers.AutoConfig.from_pretrained(self.bentomodel.path,trust_remote_code=self.trust_remote_code)
|
||||
except OpenLLMException:hf_config=transformers.AutoConfig.from_pretrained(self.model_id,trust_remote_code=self.trust_remote_code)
|
||||
config_dtype=getattr(hf_config,'torch_dtype',None)
|
||||
if config_dtype is None:config_dtype=torch.float32
|
||||
if self.__llm_dtype__=='auto':
|
||||
if config_dtype==torch.float32:torch_dtype=torch.float16
|
||||
else:torch_dtype=config_dtype
|
||||
|
||||
_map = _torch_dtype_mapping()
|
||||
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
|
||||
try:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(
|
||||
self.bentomodel.path, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
except OpenLLMException:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
|
||||
config_dtype = getattr(hf_config, 'torch_dtype', None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
if self.__llm_dtype__ == 'auto':
|
||||
if config_dtype == torch.float32:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
if self.__llm_dtype__ not in _map:raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
|
||||
torch_dtype=_map[self.__llm_dtype__]
|
||||
self.__llm_torch_dtype__=torch_dtype
|
||||
if self.__llm_dtype__ not in _map:
|
||||
raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
|
||||
torch_dtype = _map[self.__llm_dtype__]
|
||||
self.__llm_torch_dtype__ = torch_dtype
|
||||
return self.__llm_torch_dtype__
|
||||
|
||||
@property
|
||||
def _model_attrs(self):return {**self.import_kwargs[0],**self.__model_attrs}
|
||||
def _model_attrs(self):
|
||||
return {**self.import_kwargs[0], **self.__model_attrs}
|
||||
|
||||
@_model_attrs.setter
|
||||
def _model_attrs(self, value):self.__model_attrs = value
|
||||
def _model_attrs(self, value):
|
||||
self.__model_attrs = value
|
||||
|
||||
@property
|
||||
def _tokenizer_attrs(self):return {**self.import_kwargs[1],**self.__tokenizer_attrs}
|
||||
def _cascade_backend(self)->LiteralBackend:
|
||||
def _tokenizer_attrs(self):
|
||||
return {**self.import_kwargs[1], **self.__tokenizer_attrs}
|
||||
|
||||
def _cascade_backend(self) -> LiteralBackend:
|
||||
if self._has_gpus:
|
||||
if is_vllm_available():return 'vllm'
|
||||
elif is_ctranslate_available():return 'ctranslate' # XXX: base OpenLLM image should always include vLLM
|
||||
elif is_ctranslate_available():return 'ctranslate'
|
||||
else:return 'pt'
|
||||
def __setattr__(self,attr,value):
|
||||
if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
|
||||
if is_vllm_available():
|
||||
return 'vllm'
|
||||
elif is_ctranslate_available():
|
||||
return 'ctranslate' # XXX: base OpenLLM image should always include vLLM
|
||||
elif is_ctranslate_available():
|
||||
return 'ctranslate'
|
||||
else:
|
||||
return 'pt'
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}:
|
||||
raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
|
||||
super().__setattr__(attr, value)
|
||||
def __del__(self):del self.__llm_model__,self.__llm_tokenizer__,self.__llm_adapter_map__
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def __repr_keys__(self):return {'model_id','revision','backend','type'}
|
||||
def __repr_keys__(self):
|
||||
return {'model_id', 'revision', 'backend', 'type'}
|
||||
|
||||
def __repr_args__(self):
|
||||
yield 'model_id',self._model_id if not self._local else self.tag.name
|
||||
yield 'revision',self._revision if self._revision else self.tag.version
|
||||
yield 'backend',self.__llm_backend__
|
||||
yield 'type',self.llm_type
|
||||
yield 'model_id', self._model_id if not self._local else self.tag.name
|
||||
yield 'revision', self._revision if self._revision else self.tag.version
|
||||
yield 'backend', self.__llm_backend__
|
||||
yield 'type', self.llm_type
|
||||
|
||||
@property
|
||||
def import_kwargs(self):return {'device_map':'auto' if self._has_gpus else None,'torch_dtype':self._torch_dtype},{'padding_side':'left','truncation_side':'left'}
|
||||
def import_kwargs(self):
|
||||
return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {
|
||||
'padding_side': 'left',
|
||||
'truncation_side': 'left',
|
||||
}
|
||||
|
||||
@property
|
||||
def trust_remote_code(self):
|
||||
env=os.getenv('TRUST_REMOTE_CODE')
|
||||
if env is not None:return str(env).upper() in ENV_VARS_TRUE_VALUES
|
||||
env = os.getenv('TRUST_REMOTE_CODE')
|
||||
if env is not None:
|
||||
check_bool_env('TRUST_REMOTE_CODE', env)
|
||||
return self.__llm_trust_remote_code__
|
||||
|
||||
@property
|
||||
def model_id(self):return self._model_id
|
||||
def model_id(self):
|
||||
return self._model_id
|
||||
|
||||
@property
|
||||
def revision(self):return self._revision
|
||||
def revision(self):
|
||||
return self._revision
|
||||
|
||||
@property
|
||||
def tag(self):return self._tag
|
||||
def tag(self):
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def bentomodel(self):return openllm.serialisation.get(self)
|
||||
def bentomodel(self):
|
||||
return openllm.serialisation.get(self)
|
||||
|
||||
@property
|
||||
def quantization_config(self):
|
||||
if self.__llm_quantization_config__ is None:
|
||||
from ._quantisation import infer_quantisation_config
|
||||
if self._quantization_config is not None:self.__llm_quantization_config__ = self._quantization_config
|
||||
elif self._quantise is not None:self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(self,self._quantise,**self._model_attrs)
|
||||
else:raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
|
||||
|
||||
if self._quantization_config is not None:
|
||||
self.__llm_quantization_config__ = self._quantization_config
|
||||
elif self._quantise is not None:
|
||||
self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
|
||||
self, self._quantise, **self._model_attrs
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
|
||||
return self.__llm_quantization_config__
|
||||
|
||||
@property
|
||||
def has_adapters(self):return self._adapter_map is not None
|
||||
def has_adapters(self):
|
||||
return self._adapter_map is not None
|
||||
|
||||
@property
|
||||
def local(self):return self._local
|
||||
def local(self):
|
||||
return self._local
|
||||
|
||||
@property
|
||||
def quantise(self):return self._quantise
|
||||
def quantise(self):
|
||||
return self._quantise
|
||||
|
||||
@property
|
||||
def llm_type(self):return normalise_model_name(self._model_id)
|
||||
def llm_type(self):
|
||||
return normalise_model_name(self._model_id)
|
||||
|
||||
@property
|
||||
def llm_parameters(self):return (self._model_decls,self._model_attrs),self._tokenizer_attrs
|
||||
def llm_parameters(self):
|
||||
return (self._model_decls, self._model_attrs), self._tokenizer_attrs
|
||||
|
||||
@property
|
||||
def identifying_params(self):return {'configuration':self.config.model_dump_json().decode(),'model_ids':orjson.dumps(self.config['model_ids']).decode(),'model_id':self.model_id}
|
||||
def identifying_params(self):
|
||||
return {
|
||||
'configuration': self.config.model_dump_json().decode(),
|
||||
'model_ids': orjson.dumps(self.config['model_ids']).decode(),
|
||||
'model_id': self.model_id,
|
||||
}
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
if self.__llm_tokenizer__ is None:
|
||||
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
return self.__llm_tokenizer__
|
||||
|
||||
@property
|
||||
def runner(self):
|
||||
from ._runners import runner
|
||||
if self.__llm_runner__ is None:self.__llm_runner__=runner(self)
|
||||
|
||||
if self.__llm_runner__ is None:
|
||||
self.__llm_runner__ = runner(self)
|
||||
return self.__llm_runner__
|
||||
def prepare(self,adapter_type='lora',use_gradient_checking=True,**attrs):
|
||||
if self.__llm_backend__!='pt':raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
|
||||
|
||||
def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs):
|
||||
if self.__llm_backend__ != 'pt':
|
||||
raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
|
||||
from peft.mapping import get_peft_model
|
||||
from peft.utils.other import prepare_model_for_kbit_training
|
||||
model=get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model,use_gradient_checkpointing=use_gradient_checking),
|
||||
|
||||
model = get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking),
|
||||
self.config['fine_tune_strategies']
|
||||
.get(adapter_type,self.config.make_fine_tune_config(adapter_type))
|
||||
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
|
||||
.train()
|
||||
.with_config(**attrs)
|
||||
.build(),
|
||||
)
|
||||
if DEBUG:model.print_trainable_parameters()
|
||||
return model,self.tokenizer
|
||||
def prepare_for_training(self,*args,**attrs):logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Please use `prepare` instead.');return self.prepare(*args,**attrs)
|
||||
if DEBUG:
|
||||
model.print_trainable_parameters()
|
||||
return model, self.tokenizer
|
||||
|
||||
def prepare_for_training(self, *args, **attrs):
|
||||
logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.')
|
||||
return self.prepare(*args, **attrs)
|
||||
|
||||
@property
|
||||
def adapter_map(self):
|
||||
@@ -431,33 +544,49 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
return self.__llm_config__
|
||||
|
||||
|
||||
# fmt: off
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _torch_dtype_mapping()->dict[str,torch.dtype]:
|
||||
import torch; return {
|
||||
def _torch_dtype_mapping() -> dict[str, torch.dtype]:
|
||||
import torch
|
||||
|
||||
return {
|
||||
'half': torch.float16,
|
||||
'float': torch.float32,
|
||||
'float16': torch.float16,
|
||||
'float': torch.float32,
|
||||
'float32': torch.float32,
|
||||
'bfloat16': torch.bfloat16,
|
||||
}
|
||||
def normalise_model_name(name:str)->str:return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/','--'))
|
||||
def convert_peft_config_type(adapter_map:dict[str, str])->AdapterMap:
|
||||
if not is_peft_available():raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
return (
|
||||
os.path.basename(resolve_filepath(name))
|
||||
if validate_is_path(name)
|
||||
else inflection.dasherize(name.replace('/', '--'))
|
||||
)
|
||||
|
||||
|
||||
def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
if not is_peft_available():
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
resolved:AdapterMap={}
|
||||
resolved: AdapterMap = {}
|
||||
for path_or_adapter_id, name in adapter_map.items():
|
||||
if name is None:raise ValueError('Adapter name must be specified.')
|
||||
if name is None:
|
||||
raise ValueError('Adapter name must be specified.')
|
||||
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
|
||||
config_file=os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
else:
|
||||
try:
|
||||
config_file=hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
except Exception as err:
|
||||
raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err
|
||||
with open(config_file, 'r') as file:resolved_config=orjson.loads(file.read())
|
||||
_peft_type=resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved:resolved[_peft_type]=()
|
||||
resolved[_peft_type]+=(_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
with open(config_file, 'r') as file:
|
||||
resolved_config = orjson.loads(file.read())
|
||||
_peft_type = resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved:
|
||||
resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
return resolved
|
||||
|
||||
@@ -1,2 +1,9 @@
|
||||
# fmt: off
|
||||
import os,orjson,openllm_core.utils as coreutils;model_id,model_tag,adapter_map,serialization,trust_remote_code=os.environ['OPENLLM_MODEL_ID'],None,orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP',orjson.dumps(None))),os.getenv('OPENLLM_SERIALIZATION',default='safetensors'),coreutils.check_bool_env('TRUST_REMOTE_CODE',False)
|
||||
import os, orjson, openllm_core.utils as coreutils
|
||||
|
||||
model_id, model_tag, adapter_map, serialization, trust_remote_code = (
|
||||
os.environ['OPENLLM_MODEL_ID'],
|
||||
None,
|
||||
orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))),
|
||||
os.getenv('OPENLLM_SERIALIZATION', default='safetensors'),
|
||||
coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
|
||||
)
|
||||
|
||||
@@ -215,18 +215,18 @@ def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
'NvidiaGpuResource',
|
||||
'nvidia.com/gpu',
|
||||
"""NVIDIA GPU resource.
|
||||
'''NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.''',
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
'AmdGpuResource',
|
||||
'amd.com/gpu',
|
||||
"""AMD GPU resource.
|
||||
'''AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.''',
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,10 +19,10 @@ class CascadingResourceStrategy:
|
||||
resource_request: Optional[Dict[str, Any]],
|
||||
workers_per_resource: float,
|
||||
) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
'''Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
"""
|
||||
'''
|
||||
@classmethod
|
||||
def get_worker_env(
|
||||
cls,
|
||||
@@ -31,16 +31,16 @@ class CascadingResourceStrategy:
|
||||
workers_per_resource: Union[int, float],
|
||||
worker_index: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
'''Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
'''
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(
|
||||
workers_per_resource: Union[float, int], gpus: List[str], worker_index: int
|
||||
) -> str:
|
||||
"""Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string."""
|
||||
'''Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.'''
|
||||
|
||||
@@ -22,7 +22,7 @@ OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
|
||||
|
||||
def build_editable(path, package='openllm'):
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
'''Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.'''
|
||||
if not check_bool_env(OPENLLM_DEV_BUILD, default=False):
|
||||
return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# syntax=docker/dockerfile-upstream:master
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
FROM python:3.9-slim-bullseye as base-container
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base-container
|
||||
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
@@ -15,23 +13,32 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
ccache \
|
||||
curl \
|
||||
libssl-dev ca-certificates make \
|
||||
git && \
|
||||
git python3-pip && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN mkdir -p /openllm-python
|
||||
RUN mkdir -p /openllm-core
|
||||
RUN mkdir -p /openllm-client
|
||||
|
||||
# Install required dependencies
|
||||
COPY openllm-python/src src
|
||||
COPY hatch.toml README.md CHANGELOG.md openllm-python/pyproject.toml ./
|
||||
COPY openllm-python/src /openllm-python/src
|
||||
COPY hatch.toml README.md CHANGELOG.md openllm-python/pyproject.toml /openllm-python/
|
||||
|
||||
# Install all required dependencies
|
||||
# We have to install autoawq first to avoid conflict with torch, then reinstall torch with vllm
|
||||
# below
|
||||
# pip install autoawq --no-cache-dir && \
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --extra-index-url "https://huggingface.github.io/autogptq-index/whl/cu118/" \
|
||||
-v --no-cache-dir \
|
||||
pip3 install -v --no-cache-dir \
|
||||
"ray==2.6.0" "vllm==0.2.2" xformers && \
|
||||
pip install --no-cache-dir -e .
|
||||
pip3 install --no-cache-dir -e .
|
||||
|
||||
COPY openllm-core/src openllm-core/src
|
||||
COPY hatch.toml README.md CHANGELOG.md openllm-core/pyproject.toml /openllm-core/
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip3 install -v --no-cache-dir -e /openllm-core/
|
||||
|
||||
COPY openllm-client/src openllm-client/src
|
||||
COPY hatch.toml README.md CHANGELOG.md openllm-client/pyproject.toml /openllm-client/
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip3 install -v --no-cache-dir -e /openllm-client/
|
||||
|
||||
FROM base-container
|
||||
|
||||
|
||||
@@ -50,13 +50,17 @@ class RefResolver:
|
||||
else:
|
||||
raise ValueError(f'Unknown strategy: {strategy_or_version}')
|
||||
|
||||
# fmt: off
|
||||
@property
|
||||
def tag(self):return 'latest' if self.strategy in {'latest','nightly'} else repr(self.version)
|
||||
def tag(self):
|
||||
return 'latest' if self.strategy in {'latest', 'nightly'} else repr(self.version)
|
||||
|
||||
@staticmethod
|
||||
def construct_base_image(reg,strategy=None):
|
||||
if reg == 'gh': logger.warning("Setting base registry to 'gh' will affect cold start performance on GCP/AWS.")
|
||||
elif reg == 'docker': logger.warning('docker is base image is yet to be supported. Falling back to "ecr".'); reg = 'ecr'
|
||||
def construct_base_image(reg, strategy=None):
|
||||
if reg == 'gh':
|
||||
logger.warning("Setting base registry to 'gh' will affect cold start performance on GCP/AWS.")
|
||||
elif reg == 'docker':
|
||||
logger.warning('docker is base image is yet to be supported. Falling back to "ecr".')
|
||||
reg = 'ecr'
|
||||
return f'{_CONTAINER_REGISTRY[reg]}:{RefResolver.from_strategy(strategy).tag}'
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
# fmt: off
|
||||
def __dir__():import openllm_client as _client;return sorted(dir(_client))
|
||||
def __getattr__(it):import openllm_client as _client;return getattr(_client, it)
|
||||
def __dir__():
|
||||
import openllm_client as _client
|
||||
|
||||
return sorted(dir(_client))
|
||||
|
||||
|
||||
def __getattr__(it):
|
||||
import openllm_client as _client
|
||||
|
||||
return getattr(_client, it)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""OpenLLM Python client.
|
||||
'''OpenLLM Python client.
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
"""
|
||||
'''
|
||||
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Entrypoint for all third-party apps.
|
||||
'''Entrypoint for all third-party apps.
|
||||
|
||||
Currently support OpenAI, Cohere compatible API.
|
||||
|
||||
Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
'''
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
@@ -11,7 +11,7 @@ from openllm_core.utils import first_not_none
|
||||
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODELS_SCHEMA = """\
|
||||
LIST_MODELS_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -41,8 +41,8 @@ responses:
|
||||
owned_by: 'na'
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelList'
|
||||
"""
|
||||
CHAT_COMPLETIONS_SCHEMA = """\
|
||||
'''
|
||||
CHAT_COMPLETIONS_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -179,8 +179,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
"""
|
||||
COMPLETIONS_SCHEMA = """\
|
||||
'''
|
||||
COMPLETIONS_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -332,8 +332,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
"""
|
||||
HF_AGENT_SCHEMA = """\
|
||||
'''
|
||||
HF_AGENT_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -377,8 +377,8 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
"""
|
||||
HF_ADAPTERS_SCHEMA = """\
|
||||
'''
|
||||
HF_ADAPTERS_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -408,8 +408,8 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
"""
|
||||
COHERE_GENERATE_SCHEMA = """\
|
||||
'''
|
||||
COHERE_GENERATE_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -453,8 +453,8 @@ requestBody:
|
||||
stop_sequences:
|
||||
- "\\n"
|
||||
- "<|endoftext|>"
|
||||
"""
|
||||
COHERE_CHAT_SCHEMA = """\
|
||||
'''
|
||||
COHERE_CHAT_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -467,7 +467,7 @@ tags:
|
||||
- Cohere
|
||||
x-bentoml-name: cohere_chat
|
||||
summary: Creates a model response for the given chat conversation.
|
||||
"""
|
||||
'''
|
||||
|
||||
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
|
||||
|
||||
|
||||
@@ -14,11 +14,11 @@ P = ParamSpec('P')
|
||||
|
||||
|
||||
def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
'''Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
'''
|
||||
import cloudpickle
|
||||
import fs
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Serialisation utilities for OpenLLM.
|
||||
'''Serialisation utilities for OpenLLM.
|
||||
|
||||
Currently supports transformers for PyTorch, and vLLM.
|
||||
|
||||
Currently, GGML format is working in progress.
|
||||
"""
|
||||
'''
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# fmt: off
|
||||
import functools, importlib.metadata, openllm_core
|
||||
|
||||
__all__ = ['generate_labels', 'available_devices', 'device_count']
|
||||
|
||||
|
||||
def generate_labels(llm):
|
||||
return {
|
||||
'backend': llm.__llm_backend__,
|
||||
@@ -11,10 +12,25 @@ def generate_labels(llm):
|
||||
'serialisation': llm._serialisation,
|
||||
**{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
|
||||
}
|
||||
def available_devices():from ._strategies import NvidiaGpuResource;return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
|
||||
def available_devices():
|
||||
from ._strategies import NvidiaGpuResource
|
||||
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def device_count()->int:return len(available_devices())
|
||||
def __dir__():coreutils=set(dir(openllm_core.utils))|set([it for it in openllm_core.utils._extras if not it.startswith('_')]);return sorted(__all__)+sorted(list(coreutils))
|
||||
def device_count() -> int:
|
||||
return len(available_devices())
|
||||
|
||||
|
||||
def __dir__():
|
||||
coreutils = set(dir(openllm_core.utils)) | set([it for it in openllm_core.utils._extras if not it.startswith('_')])
|
||||
return sorted(__all__) + sorted(list(coreutils))
|
||||
|
||||
|
||||
def __getattr__(it):
|
||||
if hasattr(openllm_core.utils, it):return getattr(openllm_core.utils, it)
|
||||
if hasattr(openllm_core.utils, it):
|
||||
return getattr(openllm_core.utils, it)
|
||||
raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""OpenLLM CLI.
|
||||
'''OpenLLM CLI.
|
||||
|
||||
For more information see ``openllm -h``.
|
||||
"""
|
||||
'''
|
||||
|
||||
@@ -146,7 +146,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
backend_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
'LLM Optimization Options',
|
||||
help="""Optimization related options.
|
||||
help='''Optimization related options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
@@ -154,7 +154,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
""",
|
||||
''',
|
||||
),
|
||||
quantize_option(factory=cog.optgroup),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
@@ -196,7 +196,7 @@ _IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
|
||||
|
||||
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
|
||||
'''Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.'''
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
command = 'serve' if not serve_grpc else 'serve-grpc'
|
||||
@@ -233,11 +233,11 @@ _http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args
|
||||
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
'''General ``@click`` decorator with some sauce.
|
||||
|
||||
This decorator extends the default ``@click.option`` plus a factory option and factory attr to
|
||||
provide type-safe click.option or click.argument wrapper for all compatible factory.
|
||||
"""
|
||||
'''
|
||||
factory = attrs.pop('factory', click)
|
||||
factory_attr = attrs.pop('attr', 'option')
|
||||
if factory_attr != 'argument':
|
||||
@@ -346,7 +346,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
help='''Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
@@ -361,15 +361,15 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
"""
|
||||
'''
|
||||
+ (
|
||||
"""
|
||||
> [!NOTE] that this will set the mode for serving within deployment."""
|
||||
'''
|
||||
> [!NOTE] that this will set the mode for serving within deployment.'''
|
||||
if build
|
||||
else ''
|
||||
)
|
||||
+ """
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.""",
|
||||
+ '''
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.''',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
@@ -383,7 +383,7 @@ def workers_per_resource_option(
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help="""Number of workers per resource assigned.
|
||||
help='''Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -393,7 +393,7 @@ def workers_per_resource_option(
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
"""
|
||||
'''
|
||||
+ (
|
||||
"""\n
|
||||
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
|
||||
@@ -416,7 +416,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help="""Serialisation format for save/load LLM.
|
||||
help='''Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -425,7 +425,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
""",
|
||||
''',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
@@ -291,7 +291,7 @@ def _import_model(
|
||||
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
"""List all available models within the local store."""
|
||||
'''List all available models within the local store.'''
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--quiet'], standalone_mode=False)
|
||||
|
||||
@@ -94,14 +94,14 @@ else:
|
||||
|
||||
P = ParamSpec('P')
|
||||
logger = logging.getLogger('openllm')
|
||||
OPENLLM_FIGLET = """\
|
||||
OPENLLM_FIGLET = '''\
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
|
||||
"""
|
||||
'''
|
||||
|
||||
ServeCommand = t.Literal['serve', 'serve-grpc']
|
||||
|
||||
@@ -287,7 +287,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return decorator
|
||||
|
||||
def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""Additional format methods that include extensions as well as the default cli command."""
|
||||
'''Additional format methods that include extensions as well as the default cli command.'''
|
||||
from gettext import gettext as _
|
||||
|
||||
commands: list[tuple[str, click.Command]] = []
|
||||
@@ -334,7 +334,7 @@ _PACKAGE_NAME = 'openllm'
|
||||
message=f'{_PACKAGE_NAME}, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}',
|
||||
)
|
||||
def cli() -> None:
|
||||
"""\b
|
||||
'''\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
@@ -345,7 +345,7 @@ def cli() -> None:
|
||||
\b
|
||||
An open platform for operating large language models in production.
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
"""
|
||||
'''
|
||||
|
||||
|
||||
@cli.command(
|
||||
@@ -389,13 +389,13 @@ def start_command(
|
||||
max_model_len: int | None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Start any LLM as a REST server.
|
||||
'''Start any LLM as a REST server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm <start|start-http> <model_id> --<options> ...
|
||||
```
|
||||
"""
|
||||
'''
|
||||
if model_id in openllm.CONFIG_MAPPING:
|
||||
_model_name = model_id
|
||||
if deprecated_model_id is not None:
|
||||
@@ -519,13 +519,13 @@ def start_grpc_command(
|
||||
max_model_len: int | None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Start any LLM as a gRPC server.
|
||||
'''Start any LLM as a gRPC server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm start-grpc <model_id> --<options> ...
|
||||
```
|
||||
"""
|
||||
'''
|
||||
termui.warning(
|
||||
'Continuous batching is currently not yet supported with gPRC. If you want to use continuous batching with gRPC, feel free to open a GitHub issue about your usecase.\n'
|
||||
)
|
||||
@@ -955,7 +955,7 @@ def build_command(
|
||||
force_push: bool,
|
||||
**_: t.Any,
|
||||
) -> BuildBentoOutput:
|
||||
"""Package a given models into a BentoLLM.
|
||||
'''Package a given models into a BentoLLM.
|
||||
|
||||
\b
|
||||
```bash
|
||||
@@ -971,7 +971,7 @@ def build_command(
|
||||
> [!IMPORTANT]
|
||||
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
|
||||
> target also use the same Python version and architecture as build machine.
|
||||
"""
|
||||
'''
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
if model_id in openllm.CONFIG_MAPPING:
|
||||
@@ -1167,13 +1167,13 @@ class ModelItem(t.TypedDict):
|
||||
@cli.command()
|
||||
@click.option('--show-available', is_flag=True, default=True, hidden=True)
|
||||
def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
|
||||
"""List all supported models.
|
||||
'''List all supported models.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm models
|
||||
```
|
||||
"""
|
||||
'''
|
||||
result: dict[t.LiteralString, ModelItem] = {
|
||||
m: ModelItem(
|
||||
architecture=config.__openllm_architecture__,
|
||||
@@ -1216,11 +1216,11 @@ def prune_command(
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
**_: t.Any,
|
||||
) -> None:
|
||||
"""Remove all saved models, and bentos built with OpenLLM locally.
|
||||
'''Remove all saved models, and bentos built with OpenLLM locally.
|
||||
|
||||
\b
|
||||
If a model type is passed, then only prune models for that given model type.
|
||||
"""
|
||||
'''
|
||||
available: list[tuple[bentoml.Model | bentoml.Bento, ModelStore | BentoStore]] = [
|
||||
(m, model_store)
|
||||
for m in bentoml.models.list()
|
||||
@@ -1326,13 +1326,13 @@ def query_command(
|
||||
_memoized: DictStrAny,
|
||||
**_: t.Any,
|
||||
) -> None:
|
||||
"""Query a LLM interactively, from a terminal.
|
||||
'''Query a LLM interactively, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?"
|
||||
```
|
||||
"""
|
||||
'''
|
||||
if server_type == 'grpc':
|
||||
raise click.ClickException("'grpc' is currently disabled.")
|
||||
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
|
||||
@@ -1353,7 +1353,7 @@ def query_command(
|
||||
|
||||
@cli.group(cls=Extensions, hidden=True, name='extension')
|
||||
def extension_command() -> None:
|
||||
"""Extension for OpenLLM CLI."""
|
||||
'''Extension for OpenLLM CLI.'''
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -71,7 +71,7 @@ def build_container(
|
||||
@click.command(
|
||||
'build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help="""Base image builder for BentoLLM.
|
||||
help='''Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
Optionally, this can also be pushed directly to remote registry. Currently support ``docker.io``, ``ghcr.io`` and ``quay.io``.
|
||||
@@ -81,7 +81,7 @@ def build_container(
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
""",
|
||||
''',
|
||||
)
|
||||
@container_registry_option
|
||||
@click.option(
|
||||
|
||||
@@ -24,7 +24,7 @@ if t.TYPE_CHECKING:
|
||||
def cli(
|
||||
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
'''Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path).'''
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
|
||||
@@ -13,7 +13,7 @@ from openllm_cli import termui
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
'''List available bentos built by OpenLLM.'''
|
||||
mapping = {
|
||||
k: [
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ if t.TYPE_CHECKING:
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
'''This is equivalent to openllm models --show-available less the nice table.'''
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@st.composite
|
||||
def model_settings(draw: st.DrawFn):
|
||||
"""Strategy for generating ModelSettings objects."""
|
||||
'''Strategy for generating ModelSettings objects.'''
|
||||
kwargs: dict[str, t.Any] = {
|
||||
'default_id': st.text(min_size=1),
|
||||
'model_ids': st.lists(st.text(), min_size=1),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
import os, sys
|
||||
from pathlib import Path
|
||||
|
||||
# currently we are assuming the indentatio level is 2 for comments
|
||||
@@ -110,7 +109,7 @@ def main() -> int:
|
||||
special_attrs_lines: list[str] = []
|
||||
for keys, ForwardRef in codegen.get_annotations(ModelSettings).items():
|
||||
special_attrs_lines.append(
|
||||
f"{' ' * 4}{keys}:{_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n"
|
||||
f"{' ' * 4}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n"
|
||||
)
|
||||
# NOTE: inline stubs for _ConfigAttr type stubs
|
||||
config_attr_lines: list[str] = []
|
||||
@@ -119,7 +118,7 @@ def main() -> int:
|
||||
[
|
||||
' ' * 4 + line
|
||||
for line in [
|
||||
f'__openllm_{keys}__:{_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}=Field(None)\n',
|
||||
f'__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n',
|
||||
f"'''{_value_docstring[keys]}'''\n",
|
||||
]
|
||||
]
|
||||
@@ -133,7 +132,7 @@ def main() -> int:
|
||||
' ' * 2 + line
|
||||
for line in [
|
||||
'@overload\n',
|
||||
f"def __getitem__(self,item:t.Literal['{keys}'])->{_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}:...\n",
|
||||
f"def __getitem__(self, item: t.Literal['{keys}']) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n",
|
||||
]
|
||||
]
|
||||
)
|
||||
@@ -144,11 +143,11 @@ def main() -> int:
|
||||
' ' * 2 + line
|
||||
for line in [
|
||||
'@overload\n',
|
||||
"def __getitem__(self,item:t.Literal['generation_class'])->t.Type[openllm_core.GenerationConfig]:...\n",
|
||||
"def __getitem__(self, item: t.Literal['generation_class']) -> t.Type[openllm_core.GenerationConfig]: ...\n",
|
||||
'@overload\n',
|
||||
"def __getitem__(self,item:t.Literal['sampling_class'])->t.Type[openllm_core.SamplingParams]:...\n",
|
||||
"def __getitem__(self, item: t.Literal['sampling_class']) -> t.Type[openllm_core.SamplingParams]: ...\n",
|
||||
'@overload\n',
|
||||
"def __getitem__(self,item:t.Literal['extras'])->t.Dict[str, t.Any]:...\n",
|
||||
"def __getitem__(self, item: t.Literal['extras']) -> t.Dict[str, t.Any]: ...\n",
|
||||
]
|
||||
]
|
||||
)
|
||||
@@ -158,7 +157,7 @@ def main() -> int:
|
||||
lines.extend(
|
||||
[
|
||||
' ' * 2 + line
|
||||
for line in ['@overload\n', f"def __getitem__(self,item:t.Literal['{keys}'])->{type_pep563}:...\n"]
|
||||
for line in ['@overload\n', f"def __getitem__(self, item: t.Literal['{keys}']) -> {type_pep563}: ...\n"]
|
||||
]
|
||||
)
|
||||
lines.append(' ' * 2 + '# NOTE: SamplingParams arguments\n')
|
||||
@@ -167,7 +166,7 @@ def main() -> int:
|
||||
lines.extend(
|
||||
[
|
||||
' ' * 2 + line
|
||||
for line in ['@overload\n', f"def __getitem__(self,item:t.Literal['{keys}'])->{type_pep563}:...\n"]
|
||||
for line in ['@overload\n', f"def __getitem__(self, item: t.Literal['{keys}']) -> {type_pep563}: ...\n"]
|
||||
]
|
||||
)
|
||||
lines.append(' ' * 2 + '# NOTE: PeftType arguments\n')
|
||||
@@ -177,7 +176,7 @@ def main() -> int:
|
||||
' ' * 2 + line
|
||||
for line in [
|
||||
'@overload\n',
|
||||
f"def __getitem__(self,item:t.Literal['{keys.lower()}'])->t.Dict[str, t.Any]:...\n",
|
||||
f"def __getitem__(self, item: t.Literal['{keys.lower()}']) -> t.Dict[str, t.Any]: ...\n",
|
||||
]
|
||||
]
|
||||
)
|
||||
@@ -191,16 +190,11 @@ def main() -> int:
|
||||
+ [' ' * 2 + START_COMMENT, *lines, ' ' * 2 + END_COMMENT]
|
||||
+ processed[end_idx + 1 :]
|
||||
)
|
||||
with _TARGET_FILE.open('w') as f:
|
||||
f.writelines(processed)
|
||||
with _TARGET_FILE.open('w') as f: f.writelines(processed)
|
||||
|
||||
with _TARGET_AUTO_FILE.open('r') as f:
|
||||
processed = f.readlines()
|
||||
with _TARGET_AUTO_FILE.open('r') as f: processed = f.readlines()
|
||||
|
||||
start_auto_stubs_idx, end_auto_stubs_idx = (
|
||||
processed.index(' ' * 2 + START_AUTO_STUBS_COMMENT),
|
||||
processed.index(' ' * 2 + END_AUTO_STUBS_COMMENT),
|
||||
)
|
||||
start_auto_stubs_idx, end_auto_stubs_idx = processed.index(' ' * 2 + START_AUTO_STUBS_COMMENT), processed.index(' ' * 2 + END_AUTO_STUBS_COMMENT)
|
||||
lines = []
|
||||
for model, class_name in CONFIG_MAPPING_NAMES.items():
|
||||
lines.extend(
|
||||
@@ -209,35 +203,21 @@ def main() -> int:
|
||||
for line in [
|
||||
'@t.overload\n',
|
||||
'@classmethod\n',
|
||||
f"def for_model(cls,model_name:t.Literal['{model}'],**attrs:t.Any)->openllm_core.config.{class_name}:...\n",
|
||||
f"def for_model(cls, model_name: t.Literal['{model}'], **attrs: t.Any) -> openllm_core.config.{class_name}: ...\n",
|
||||
]
|
||||
]
|
||||
)
|
||||
processed = (
|
||||
processed[:start_auto_stubs_idx]
|
||||
+ [' ' * 2 + START_AUTO_STUBS_COMMENT, *lines, ' ' * 2 + END_AUTO_STUBS_COMMENT]
|
||||
+ processed[end_auto_stubs_idx + 1 :]
|
||||
)
|
||||
with _TARGET_AUTO_FILE.open('w') as f:
|
||||
f.writelines(processed)
|
||||
processed = processed[:start_auto_stubs_idx] + [' ' * 2 + START_AUTO_STUBS_COMMENT, *lines, ' ' * 2 + END_AUTO_STUBS_COMMENT] + processed[end_auto_stubs_idx + 1 :]
|
||||
with _TARGET_AUTO_FILE.open('w') as f: f.writelines(processed)
|
||||
|
||||
with _TARGET_INIT_FILE.open('r') as f:
|
||||
processed = f.readlines()
|
||||
start_import_stubs_idx, end_import_stubs_idx = (
|
||||
processed.index(START_IMPORT_STUBS_COMMENT),
|
||||
processed.index(END_IMPORT_STUBS_COMMENT),
|
||||
)
|
||||
lines = f'from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING,CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES,AutoConfig as AutoConfig,{",".join([a+" as "+a for a in CONFIG_MAPPING_NAMES.values()])}\n'
|
||||
processed = (
|
||||
processed[:start_import_stubs_idx]
|
||||
+ [START_IMPORT_STUBS_COMMENT, lines, END_IMPORT_STUBS_COMMENT]
|
||||
+ processed[end_import_stubs_idx + 1 :]
|
||||
)
|
||||
with _TARGET_INIT_FILE.open('w') as f:
|
||||
f.writelines(processed)
|
||||
with _TARGET_INIT_FILE.open('r') as f: processed = f.readlines()
|
||||
|
||||
start_import_stubs_idx, end_import_stubs_idx = processed.index(START_IMPORT_STUBS_COMMENT), processed.index(END_IMPORT_STUBS_COMMENT)
|
||||
lines = f'from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, {", ".join([a+" as "+a for a in CONFIG_MAPPING_NAMES.values()])}\n'
|
||||
processed = processed[:start_import_stubs_idx] + [START_IMPORT_STUBS_COMMENT, lines, END_IMPORT_STUBS_COMMENT] + processed[end_import_stubs_idx + 1 :]
|
||||
with _TARGET_INIT_FILE.open('w') as f: f.writelines(processed)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise SystemExit(main())
|
||||
if __name__ == '__main__': raise SystemExit(main())
|
||||
|
||||
Reference in New Issue
Block a user