mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-09 02:32:51 -04:00
fix(infra): conform ruff to 150 LL (#781)
Generally correctly format it with ruff format and manual style Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import logging as _logging, os as _os, pathlib as _pathlib, warnings as _warnings
|
||||
from openllm_cli import _sdk
|
||||
from . import utils as utils
|
||||
|
||||
if utils.DEBUG:
|
||||
utils.set_debug_mode(True); _logging.basicConfig(level=_logging.NOTSET)
|
||||
utils.set_debug_mode(True)
|
||||
_logging.basicConfig(level=_logging.NOTSET)
|
||||
else:
|
||||
# configuration for bitsandbytes before import
|
||||
_os.environ['BITSANDBYTES_NOWELCOME'] = _os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
@@ -30,8 +32,11 @@ __lazy = utils.LazyModule( # NOTE: update this to sys.modules[__name__] once my
|
||||
'_llm': ['LLM'],
|
||||
},
|
||||
extra_objects={
|
||||
'COMPILED': COMPILED, 'start': _sdk.start, 'build': _sdk.build, #
|
||||
'import_model': _sdk.import_model, 'list_models': _sdk.list_models, #
|
||||
'COMPILED': COMPILED,
|
||||
'start': _sdk.start,
|
||||
'build': _sdk.build, #
|
||||
'import_model': _sdk.import_model,
|
||||
'list_models': _sdk.list_models, #
|
||||
},
|
||||
)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
'''OpenLLM.
|
||||
"""OpenLLM.
|
||||
===========
|
||||
|
||||
An open platform for operating large language models in production.
|
||||
@@ -8,36 +8,17 @@ Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
* Option to bring your own fine-tuned LLMs
|
||||
* Online Serving with HTTP, gRPC, SSE or custom API
|
||||
* Native integration with BentoML, LangChain, OpenAI compatible endpoints, LlamaIndex for custom LLM apps
|
||||
'''
|
||||
"""
|
||||
|
||||
# update-config-stubs.py: import stubs start
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
|
||||
# update-config-stubs.py: import stubs stop
|
||||
|
||||
from openllm_cli._sdk import (
|
||||
build as build,
|
||||
import_model as import_model,
|
||||
list_models as list_models,
|
||||
start as start,
|
||||
)
|
||||
from openllm_core._configuration import (
|
||||
GenerationConfig as GenerationConfig,
|
||||
LLMConfig as LLMConfig,
|
||||
SamplingParams as SamplingParams,
|
||||
)
|
||||
from openllm_core._schemas import (
|
||||
GenerationInput as GenerationInput,
|
||||
GenerationOutput as GenerationOutput,
|
||||
MetadataOutput as MetadataOutput,
|
||||
)
|
||||
from openllm_cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
|
||||
from openllm_core._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput
|
||||
|
||||
from . import (
|
||||
bundle as bundle,
|
||||
client as client,
|
||||
exceptions as exceptions,
|
||||
serialisation as serialisation,
|
||||
utils as utils,
|
||||
)
|
||||
from . import bundle as bundle, client as client, exceptions as exceptions, serialisation as serialisation, utils as utils
|
||||
from ._deprecated import Runner as Runner
|
||||
from ._llm import LLM as LLM
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
if __name__ == '__main__': from openllm_cli.entrypoint import cli; cli()
|
||||
if __name__ == '__main__':
|
||||
from openllm_cli.entrypoint import cli
|
||||
|
||||
cli()
|
||||
|
||||
@@ -7,15 +7,22 @@ from openllm_core.utils import first_not_none, getenv, is_vllm_available
|
||||
__all__ = ['Runner']
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def Runner(
|
||||
model_name: str, ensure_available: bool = True, #
|
||||
init_local: bool = False, backend: LiteralBackend | None = None, #
|
||||
llm_config: openllm.LLMConfig | None = None, **attrs: t.Any,
|
||||
model_name: str,
|
||||
ensure_available: bool = True, #
|
||||
init_local: bool = False,
|
||||
backend: LiteralBackend | None = None, #
|
||||
llm_config: openllm.LLMConfig | None = None,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
if llm_config is None: llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
if not ensure_available: logger.warning("'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation.")
|
||||
if llm_config is None:
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
if not ensure_available:
|
||||
logger.warning("'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation.")
|
||||
model_id = attrs.get('model_id', os.getenv('OPENLLM_MODEL_ID', llm_config['default_id']))
|
||||
warnings.warn(f'''\
|
||||
warnings.warn(
|
||||
f"""\
|
||||
Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax:
|
||||
|
||||
```python
|
||||
@@ -26,11 +33,15 @@ def Runner(
|
||||
@svc.api(...)
|
||||
async def chat(input: str) -> str:
|
||||
async for it in llm.generate_iterator(input): print(it)
|
||||
```''', DeprecationWarning, stacklevel=2)
|
||||
attrs.update(
|
||||
{
|
||||
'model_id': model_id, 'quantize': getenv('QUANTIZE', var=['QUANTISE'], default=attrs.get('quantize', None)), #
|
||||
'serialisation': getenv('serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']),
|
||||
}
|
||||
```""",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return openllm.LLM(backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), llm_config=llm_config, embedded=init_local, **attrs).runner
|
||||
attrs.update({
|
||||
'model_id': model_id,
|
||||
'quantize': getenv('QUANTIZE', var=['QUANTISE'], default=attrs.get('quantize', None)), #
|
||||
'serialisation': getenv('serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']),
|
||||
})
|
||||
return openllm.LLM(
|
||||
backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), llm_config=llm_config, embedded=init_local, **attrs
|
||||
).runner
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
def prepare_logits_processor(config):
|
||||
import transformers
|
||||
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
if generation_config['temperature'] >= 1e-5 and generation_config['temperature'] != 1.0:
|
||||
@@ -11,16 +12,27 @@ def prepare_logits_processor(config):
|
||||
if generation_config['top_k'] > 0:
|
||||
logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
return logits_processor
|
||||
|
||||
|
||||
# NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used.
|
||||
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length']
|
||||
|
||||
|
||||
def get_context_length(config):
|
||||
rope_scaling = getattr(config, 'rope_scaling', None)
|
||||
rope_scaling_factor = config.rope_scaling['factor'] if rope_scaling else 1.0
|
||||
for key in SEQLEN_KEYS:
|
||||
if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key))
|
||||
if getattr(config, key, None) is not None:
|
||||
return int(rope_scaling_factor * getattr(config, key))
|
||||
return 2048
|
||||
def is_sentence_complete(output): return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”'))
|
||||
|
||||
|
||||
def is_sentence_complete(output):
|
||||
return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”'))
|
||||
|
||||
|
||||
def is_partial_stop(output, stop_str):
|
||||
for i in range(min(len(output), len(stop_str))):
|
||||
if stop_str.startswith(output[-i:]): return True
|
||||
if stop_str.startswith(output[-i:]):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -48,7 +48,8 @@ ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]
|
||||
@attr.define(slots=False, repr=False, init=False)
|
||||
class LLM(t.Generic[M, T]):
|
||||
async def generate(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs):
|
||||
if adapter_name is not None and self.__llm_backend__ != 'pt': raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
|
||||
if adapter_name is not None and self.__llm_backend__ != 'pt':
|
||||
raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for result in self.generate_iterator(
|
||||
@@ -57,18 +58,18 @@ class LLM(t.Generic[M, T]):
|
||||
for output in result.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
if (final_result := result) is None: raise RuntimeError('No result is returned.')
|
||||
if (final_result := result) is None:
|
||||
raise RuntimeError('No result is returned.')
|
||||
return final_result.with_options(
|
||||
prompt=prompt,
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
],
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs],
|
||||
)
|
||||
|
||||
async def generate_iterator(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs):
|
||||
from bentoml._internal.runner.runner_handle import DummyRunnerHandle
|
||||
if adapter_name is not None and self.__llm_backend__ != 'pt': raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
|
||||
|
||||
if adapter_name is not None and self.__llm_backend__ != 'pt':
|
||||
raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
|
||||
|
||||
if isinstance(self.runner._runner_handle, DummyRunnerHandle):
|
||||
if os.getenv('BENTO_PATH') is not None:
|
||||
@@ -79,10 +80,13 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
stop_token_ids = stop_token_ids or []
|
||||
eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
|
||||
if eos_token_id and not isinstance(eos_token_id, list): eos_token_id = [eos_token_id]
|
||||
if eos_token_id and not isinstance(eos_token_id, list):
|
||||
eos_token_id = [eos_token_id]
|
||||
stop_token_ids.extend(eos_token_id or [])
|
||||
if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids: stop_token_ids.append(config_eos)
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids: stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids:
|
||||
stop_token_ids.append(config_eos)
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if stop is None:
|
||||
stop = set()
|
||||
elif isinstance(stop, str):
|
||||
@@ -90,16 +94,20 @@ class LLM(t.Generic[M, T]):
|
||||
else:
|
||||
stop = set(stop)
|
||||
for tid in stop_token_ids:
|
||||
if tid: stop.add(self.tokenizer.decode(tid))
|
||||
if tid:
|
||||
stop.add(self.tokenizer.decode(tid))
|
||||
|
||||
if prompt_token_ids is None:
|
||||
if prompt is None: raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
if prompt is None:
|
||||
raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
|
||||
request_id = gen_random_uuid() if request_id is None else request_id
|
||||
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
|
||||
try:
|
||||
generator = self.runner.generate_iterator.async_stream(prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True))
|
||||
generator = self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True)
|
||||
)
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Failed to start generation task: {err}') from err
|
||||
|
||||
@@ -118,11 +126,18 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
# NOTE: If you are here to see how generate_iterator and generate works, see above.
|
||||
# The below are mainly for internal implementation that you don't have to worry about.
|
||||
_model_id: str; _revision: t.Optional[str] #
|
||||
_model_id: str
|
||||
_revision: t.Optional[str] #
|
||||
_quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
|
||||
_quantise: t.Optional[LiteralQuantise]; _model_decls: t.Tuple[t.Any, ...]; __model_attrs: t.Dict[str, t.Any] #
|
||||
__tokenizer_attrs: t.Dict[str, t.Any]; _tag: bentoml.Tag; _adapter_map: t.Optional[AdapterMap] #
|
||||
_serialisation: LiteralSerialisation; _local: bool; _max_model_len: t.Optional[int] #
|
||||
_quantise: t.Optional[LiteralQuantise]
|
||||
_model_decls: t.Tuple[t.Any, ...]
|
||||
__model_attrs: t.Dict[str, t.Any] #
|
||||
__tokenizer_attrs: t.Dict[str, t.Any]
|
||||
_tag: bentoml.Tag
|
||||
_adapter_map: t.Optional[AdapterMap] #
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_max_model_len: t.Optional[int] #
|
||||
_gpu_memory_utilization: float
|
||||
|
||||
__llm_dtype__: t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']] = 'auto'
|
||||
@@ -159,20 +174,28 @@ class LLM(t.Generic[M, T]):
|
||||
):
|
||||
torch_dtype = attrs.pop('torch_dtype', None) # backward compatible
|
||||
if torch_dtype is not None:
|
||||
warnings.warn('The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.', DeprecationWarning, stacklevel=3); dtype = torch_dtype
|
||||
warnings.warn(
|
||||
'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.', DeprecationWarning, stacklevel=3
|
||||
)
|
||||
dtype = torch_dtype
|
||||
_local = False
|
||||
if validate_is_path(model_id): model_id, _local = resolve_filepath(model_id), True
|
||||
if validate_is_path(model_id):
|
||||
model_id, _local = resolve_filepath(model_id), True
|
||||
backend = getenv('backend', default=backend)
|
||||
if backend is None: backend = self._cascade_backend()
|
||||
if backend is None:
|
||||
backend = self._cascade_backend()
|
||||
dtype = getenv('dtype', default=dtype, var=['TORCH_DTYPE'])
|
||||
if dtype is None: logger.warning('Setting dtype to auto. Inferring from framework specific models'); dtype = 'auto'
|
||||
if dtype is None:
|
||||
logger.warning('Setting dtype to auto. Inferring from framework specific models')
|
||||
dtype = 'auto'
|
||||
quantize = getenv('quantize', default=quantize, var=['QUANITSE'])
|
||||
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
|
||||
if model_tag is None:
|
||||
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
|
||||
if model_version: model_tag = f'{model_tag}:{model_version}'
|
||||
if model_version:
|
||||
model_tag = f'{model_tag}:{model_version}'
|
||||
|
||||
self.__attrs_init__(
|
||||
model_id=model_id,
|
||||
@@ -200,42 +223,57 @@ class LLM(t.Generic[M, T]):
|
||||
model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
|
||||
# resolve the tag
|
||||
self._tag = model.tag
|
||||
if not _eager and embedded: raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
|
||||
if not _eager and embedded:
|
||||
raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
|
||||
if embedded:
|
||||
logger.warning('NOT RECOMMENDED in production and SHOULD ONLY used for development.'); self.runner.init_local(quiet=True)
|
||||
logger.warning('NOT RECOMMENDED in production and SHOULD ONLY used for development.')
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
class _Quantise:
|
||||
@staticmethod
|
||||
def pt(llm: LLM, quantise=None): return quantise
|
||||
def pt(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
|
||||
@staticmethod
|
||||
def vllm(llm: LLM, quantise=None): return quantise
|
||||
def vllm(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
|
||||
@staticmethod
|
||||
def ctranslate(llm: LLM, quantise=None):
|
||||
if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}: raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise == 'int8': quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}:
|
||||
raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise == 'int8':
|
||||
quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
return quantise
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]:
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None: logger.warning("revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version)
|
||||
if model_version is not None:
|
||||
logger.warning("revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id):
|
||||
model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id))
|
||||
return f'{backend}-{normalise_model_name(model_id)}', model_version
|
||||
|
||||
@functools.cached_property
|
||||
def _has_gpus(self):
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err, num_gpus = cuda.cuDeviceGetCount()
|
||||
err, _ = cuda.cuDeviceGetCount()
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to get CUDA device count.')
|
||||
return True
|
||||
except (ImportError, RuntimeError):
|
||||
return False
|
||||
|
||||
@property
|
||||
def _torch_dtype(self):
|
||||
import torch, transformers
|
||||
|
||||
_map = _torch_dtype_mapping()
|
||||
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
|
||||
try:
|
||||
@@ -243,23 +281,32 @@ class LLM(t.Generic[M, T]):
|
||||
except OpenLLMException:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
|
||||
config_dtype = getattr(hf_config, 'torch_dtype', None)
|
||||
if config_dtype is None: config_dtype = torch.float32
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
if self.__llm_dtype__ == 'auto':
|
||||
if config_dtype == torch.float32:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
if self.__llm_dtype__ not in _map: raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
|
||||
if self.__llm_dtype__ not in _map:
|
||||
raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
|
||||
torch_dtype = _map[self.__llm_dtype__]
|
||||
self.__llm_torch_dtype__ = torch_dtype
|
||||
return self.__llm_torch_dtype__
|
||||
|
||||
@property
|
||||
def _model_attrs(self): return {**self.import_kwargs[0], **self.__model_attrs}
|
||||
def _model_attrs(self):
|
||||
return {**self.import_kwargs[0], **self.__model_attrs}
|
||||
|
||||
@_model_attrs.setter
|
||||
def _model_attrs(self, model_attrs): self.__model_attrs = model_attrs
|
||||
def _model_attrs(self, model_attrs):
|
||||
self.__model_attrs = model_attrs
|
||||
|
||||
@property
|
||||
def _tokenizer_attrs(self): return {**self.import_kwargs[1], **self.__tokenizer_attrs}
|
||||
def _tokenizer_attrs(self):
|
||||
return {**self.import_kwargs[1], **self.__tokenizer_attrs}
|
||||
|
||||
def _cascade_backend(self) -> LiteralBackend:
|
||||
logger.warning('It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.')
|
||||
if self._has_gpus:
|
||||
@@ -271,35 +318,61 @@ class LLM(t.Generic[M, T]):
|
||||
return 'ctranslate'
|
||||
else:
|
||||
return 'pt'
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}: raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
|
||||
if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}:
|
||||
raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
|
||||
except AttributeError:
|
||||
pass
|
||||
def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type))
|
||||
def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
|
||||
|
||||
def __repr_args__(self):
|
||||
yield from (
|
||||
('model_id', self._model_id if not self._local else self.tag.name),
|
||||
('revision', self._revision if self._revision else self.tag.version),
|
||||
('backend', self.__llm_backend__),
|
||||
('type', self.llm_type),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'
|
||||
|
||||
@property
|
||||
def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
|
||||
def import_kwargs(self):
|
||||
return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
|
||||
|
||||
@property
|
||||
def trust_remote_code(self):
|
||||
env = os.getenv('TRUST_REMOTE_CODE')
|
||||
if env is not None: return check_bool_env('TRUST_REMOTE_CODE', env)
|
||||
if env is not None:
|
||||
return check_bool_env('TRUST_REMOTE_CODE', env)
|
||||
return self.__llm_trust_remote_code__
|
||||
|
||||
@property
|
||||
def model_id(self): return self._model_id
|
||||
def model_id(self):
|
||||
return self._model_id
|
||||
|
||||
@property
|
||||
def revision(self): return self._revision
|
||||
def revision(self):
|
||||
return self._revision
|
||||
|
||||
@property
|
||||
def tag(self): return self._tag
|
||||
def tag(self):
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def bentomodel(self): return openllm.serialisation.get(self)
|
||||
def bentomodel(self):
|
||||
return openllm.serialisation.get(self)
|
||||
|
||||
@property
|
||||
def quantization_config(self):
|
||||
if self.__llm_quantization_config__ is None:
|
||||
from ._quantisation import infer_quantisation_config
|
||||
|
||||
if self._quantization_config is not None:
|
||||
self.__llm_quantization_config__ = self._quantization_config
|
||||
elif self._quantise is not None:
|
||||
@@ -307,16 +380,27 @@ class LLM(t.Generic[M, T]):
|
||||
else:
|
||||
raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
|
||||
return self.__llm_quantization_config__
|
||||
|
||||
@property
|
||||
def has_adapters(self): return self._adapter_map is not None
|
||||
def has_adapters(self):
|
||||
return self._adapter_map is not None
|
||||
|
||||
@property
|
||||
def local(self): return self._local
|
||||
def local(self):
|
||||
return self._local
|
||||
|
||||
@property
|
||||
def quantise(self): return self._quantise
|
||||
def quantise(self):
|
||||
return self._quantise
|
||||
|
||||
@property
|
||||
def llm_type(self): return normalise_model_name(self._model_id)
|
||||
def llm_type(self):
|
||||
return normalise_model_name(self._model_id)
|
||||
|
||||
@property
|
||||
def llm_parameters(self): return (self._model_decls, self._model_attrs), self._tokenizer_attrs
|
||||
def llm_parameters(self):
|
||||
return (self._model_decls, self._model_attrs), self._tokenizer_attrs
|
||||
|
||||
@property
|
||||
def identifying_params(self):
|
||||
return {
|
||||
@@ -324,43 +408,55 @@ class LLM(t.Generic[M, T]):
|
||||
'model_ids': orjson.dumps(self.config['model_ids']).decode(),
|
||||
'model_id': self.model_id,
|
||||
}
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
if self.__llm_tokenizer__ is None:
|
||||
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
return self.__llm_tokenizer__
|
||||
|
||||
@property
|
||||
def runner(self):
|
||||
from ._runners import runner
|
||||
if self.__llm_runner__ is None: self.__llm_runner__ = runner(self)
|
||||
|
||||
if self.__llm_runner__ is None:
|
||||
self.__llm_runner__ = runner(self)
|
||||
return self.__llm_runner__
|
||||
|
||||
def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs):
|
||||
if self.__llm_backend__ != 'pt': raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
|
||||
if self.__llm_backend__ != 'pt':
|
||||
raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
|
||||
from peft.mapping import get_peft_model
|
||||
from peft.utils.other import prepare_model_for_kbit_training
|
||||
|
||||
model = get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking),
|
||||
self.config['fine_tune_strategies']
|
||||
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
|
||||
.train().with_config(**attrs).build(),
|
||||
self.config['fine_tune_strategies'].get(adapter_type, self.config.make_fine_tune_config(adapter_type)).train().with_config(**attrs).build(),
|
||||
)
|
||||
if DEBUG: model.print_trainable_parameters()
|
||||
if DEBUG:
|
||||
model.print_trainable_parameters()
|
||||
return model, self.tokenizer
|
||||
def prepare_for_training(self, *args, **attrs): logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.'); return self.prepare(*args, **attrs)
|
||||
|
||||
def prepare_for_training(self, *args, **attrs):
|
||||
logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.')
|
||||
return self.prepare(*args, **attrs)
|
||||
|
||||
@property
|
||||
def adapter_map(self):
|
||||
if not is_peft_available(): raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
if not self.has_adapters: raise AttributeError('Adapter map is not available.')
|
||||
if not is_peft_available():
|
||||
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
if not self.has_adapters:
|
||||
raise AttributeError('Adapter map is not available.')
|
||||
assert self._adapter_map is not None
|
||||
if self.__llm_adapter_map__ is None:
|
||||
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
|
||||
for adapter_type, adapter_tuple in self._adapter_map.items():
|
||||
base = first_not_none(
|
||||
self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type),
|
||||
)
|
||||
for adapter in adapter_tuple: _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
|
||||
base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
|
||||
for adapter in adapter_tuple:
|
||||
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
|
||||
self.__llm_adapter_map__ = _map
|
||||
return self.__llm_adapter_map__
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
if self.__llm_model__ is None:
|
||||
@@ -368,7 +464,10 @@ class LLM(t.Generic[M, T]):
|
||||
# If OOM, then it is probably you don't have enough VRAM to run this model.
|
||||
if self.__llm_backend__ == 'pt':
|
||||
import torch
|
||||
loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
|
||||
|
||||
loaded_in_kbit = (
|
||||
getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
|
||||
try:
|
||||
model = model.to('cuda')
|
||||
@@ -381,9 +480,11 @@ class LLM(t.Generic[M, T]):
|
||||
model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
|
||||
self.__llm_model__ = model
|
||||
return self.__llm_model__
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
import transformers
|
||||
|
||||
if self.__llm_config__ is None:
|
||||
if self.__llm_backend__ == 'ctranslate':
|
||||
try:
|
||||
@@ -397,26 +498,41 @@ class LLM(t.Generic[M, T]):
|
||||
).model_construct_env(**self._model_attrs)
|
||||
break
|
||||
else:
|
||||
raise OpenLLMException(f"Failed to infer the configuration class. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}")
|
||||
raise OpenLLMException(
|
||||
f"Failed to infer the configuration class. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}"
|
||||
)
|
||||
else:
|
||||
config = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
|
||||
self.__llm_config__ = config
|
||||
return self.__llm_config__
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _torch_dtype_mapping() -> dict[str, torch.dtype]:
|
||||
import torch; return {
|
||||
'half': torch.float16, 'float16': torch.float16, #
|
||||
'float': torch.float32, 'float32': torch.float32, #
|
||||
import torch
|
||||
|
||||
return {
|
||||
'half': torch.float16,
|
||||
'float16': torch.float16, #
|
||||
'float': torch.float32,
|
||||
'float32': torch.float32, #
|
||||
'bfloat16': torch.bfloat16,
|
||||
}
|
||||
def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--'))
|
||||
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--'))
|
||||
|
||||
|
||||
def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
if not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
if not is_peft_available():
|
||||
raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
resolved: AdapterMap = {}
|
||||
for path_or_adapter_id, name in adapter_map.items():
|
||||
if name is None: raise ValueError('Adapter name must be specified.')
|
||||
if name is None:
|
||||
raise ValueError('Adapter name must be specified.')
|
||||
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
|
||||
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
else:
|
||||
@@ -424,8 +540,10 @@ def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
except Exception as err:
|
||||
raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err
|
||||
with open(config_file, 'r') as file: resolved_config = orjson.loads(file.read())
|
||||
with open(config_file, 'r') as file:
|
||||
resolved_config = orjson.loads(file.read())
|
||||
_peft_type = resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved: resolved[_peft_type] = ()
|
||||
if _peft_type not in resolved:
|
||||
resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
return resolved
|
||||
|
||||
@@ -8,16 +8,7 @@ from peft.peft_model import PeftModel, PeftModelForCausalLM, PeftModelForSeq2Seq
|
||||
from bentoml import Model, Tag
|
||||
from openllm_core import LLMConfig
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import (
|
||||
AdapterMap,
|
||||
AdapterType,
|
||||
LiteralBackend,
|
||||
LiteralDtype,
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
M,
|
||||
T,
|
||||
)
|
||||
from openllm_core._typing_compat import AdapterMap, AdapterType, LiteralBackend, LiteralDtype, LiteralQuantise, LiteralSerialisation, M, T
|
||||
|
||||
from ._quantisation import QuantizationConfig
|
||||
from ._runners import Runner
|
||||
@@ -121,9 +112,7 @@ class LLM(Generic[M, T]):
|
||||
def runner(self) -> Runner[M, T]: ...
|
||||
@property
|
||||
def adapter_map(self) -> ResolvedAdapterMap: ...
|
||||
def prepare(
|
||||
self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any
|
||||
) -> Tuple[InjectedModel, T]: ...
|
||||
def prepare(self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any) -> Tuple[InjectedModel, T]: ...
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.utils import is_autoawq_available, is_autogptq_available, is_bitsandbytes_available
|
||||
|
||||
|
||||
def infer_quantisation_config(llm, quantise, **attrs):
|
||||
import torch, transformers
|
||||
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
|
||||
|
||||
@@ -9,18 +9,10 @@ from ._llm import LLM
|
||||
QuantizationConfig = Union[BitsAndBytesConfig, GPTQConfig, AwqConfig]
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: Literal['int8', 'int4'], **attrs: Any
|
||||
) -> tuple[BitsAndBytesConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: Literal['int8', 'int4'], **attrs: Any) -> tuple[BitsAndBytesConfig, Dict[str, Any]]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: Literal['gptq'], **attrs: Any
|
||||
) -> tuple[GPTQConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: Literal['gptq'], **attrs: Any) -> tuple[GPTQConfig, Dict[str, Any]]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: Literal['awq'], **attrs: Any
|
||||
) -> tuple[AwqConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: Literal['awq'], **attrs: Any) -> tuple[AwqConfig, Dict[str, Any]]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: LiteralQuantise, **attrs: Any
|
||||
) -> tuple[QuantizationConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: LiteralQuantise, **attrs: Any) -> tuple[QuantizationConfig, Dict[str, Any]]: ...
|
||||
|
||||
@@ -11,13 +11,17 @@ if t.TYPE_CHECKING:
|
||||
_registry = {}
|
||||
__all__ = ['runner']
|
||||
|
||||
|
||||
def registry(cls=None, *, alias=None):
|
||||
def decorator(_cls):
|
||||
_registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls
|
||||
return _cls
|
||||
if cls is None: return decorator
|
||||
|
||||
if cls is None:
|
||||
return decorator
|
||||
return decorator(cls)
|
||||
|
||||
|
||||
def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
|
||||
try:
|
||||
assert llm.bentomodel
|
||||
@@ -25,32 +29,36 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
|
||||
raise RuntimeError(f'Failed to locate {llm.bentomodel}: {err}') from err
|
||||
|
||||
return types.new_class(
|
||||
llm.config.__class__.__name__[:-6] + 'Runner', (bentoml.Runner,), #
|
||||
exec_body=lambda ns: ns.update(
|
||||
{
|
||||
'llm_type': llm.llm_type, 'identifying_params': llm.identifying_params, #
|
||||
'llm_tag': llm.tag, 'llm': llm, 'config': llm.config, 'backend': llm.__llm_backend__, #
|
||||
'__module__': llm.__module__, '__repr__': ReprMixin.__repr__, #
|
||||
'__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
|
||||
'__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
|
||||
'__repr_args__': lambda _: (
|
||||
(
|
||||
'runner_methods',
|
||||
{
|
||||
method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None}
|
||||
for method in _.runner_methods
|
||||
},
|
||||
),
|
||||
('config', llm.config.model_dump(flatten=True)),
|
||||
('llm_type', llm.llm_type),
|
||||
('backend', llm.__llm_backend__),
|
||||
('llm_tag', llm.tag),
|
||||
llm.config.__class__.__name__[:-6] + 'Runner',
|
||||
(bentoml.Runner,), #
|
||||
exec_body=lambda ns: ns.update({
|
||||
'llm_type': llm.llm_type,
|
||||
'identifying_params': llm.identifying_params, #
|
||||
'llm_tag': llm.tag,
|
||||
'llm': llm,
|
||||
'config': llm.config,
|
||||
'backend': llm.__llm_backend__, #
|
||||
'__module__': llm.__module__,
|
||||
'__repr__': ReprMixin.__repr__, #
|
||||
'__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
|
||||
'__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
|
||||
'__repr_args__': lambda _: (
|
||||
(
|
||||
'runner_methods',
|
||||
{
|
||||
method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None}
|
||||
for method in _.runner_methods
|
||||
},
|
||||
),
|
||||
'has_adapters': llm.has_adapters,
|
||||
'template': llm.config.template,
|
||||
'system_message': llm.config.system_message,
|
||||
}
|
||||
),
|
||||
('config', llm.config.model_dump(flatten=True)),
|
||||
('llm_type', llm.llm_type),
|
||||
('backend', llm.__llm_backend__),
|
||||
('llm_tag', llm.tag),
|
||||
),
|
||||
'has_adapters': llm.has_adapters,
|
||||
'template': llm.config.template,
|
||||
'system_message': llm.config.system_message,
|
||||
}),
|
||||
)(
|
||||
_registry[llm.__llm_backend__],
|
||||
name=f"llm-{llm.config['start_name']}-runner",
|
||||
@@ -59,32 +67,44 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
|
||||
runnable_init_params={'llm': llm},
|
||||
)
|
||||
|
||||
|
||||
@registry
|
||||
class CTranslateRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
|
||||
if not is_ctranslate_available():
|
||||
raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
|
||||
cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
|
||||
tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
|
||||
async for request_output in self.model.async_generate_tokens(tokens, **sampling_params):
|
||||
if config['logprobs']: cumulative_logprob += request_output.log_prob
|
||||
if config['logprobs']:
|
||||
cumulative_logprob += request_output.log_prob
|
||||
output_token_ids.append(request_output.token_id)
|
||||
text = self.tokenizer.decode(
|
||||
output_token_ids[input_len:], skip_special_tokens=True, #
|
||||
spaces_between_special_tokens=False, clean_up_tokenization_spaces=True, #
|
||||
output_token_ids[input_len:],
|
||||
skip_special_tokens=True, #
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True, #
|
||||
)
|
||||
yield GenerationOutput(
|
||||
prompt_token_ids=prompt_token_ids, #
|
||||
prompt='', finished=request_output.is_last, request_id=request_id, #
|
||||
prompt_token_ids=prompt_token_ids, #
|
||||
prompt='',
|
||||
finished=request_output.is_last,
|
||||
request_id=request_id, #
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0, text=text, finish_reason=None, #
|
||||
token_ids=output_token_ids[input_len:], cumulative_logprob=cumulative_logprob, #
|
||||
index=0,
|
||||
text=text,
|
||||
finish_reason=None, #
|
||||
token_ids=output_token_ids[input_len:],
|
||||
cumulative_logprob=cumulative_logprob, #
|
||||
# TODO: logprobs, but seems like we don't have access to the raw logits
|
||||
)
|
||||
],
|
||||
@@ -95,30 +115,39 @@ class CTranslateRunnable(bentoml.Runnable):
|
||||
class vLLMRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
|
||||
if not is_vllm_available():
|
||||
raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
|
||||
import vllm
|
||||
|
||||
self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
try:
|
||||
self.model = vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(
|
||||
worker_use_ray=False, engine_use_ray=False, #
|
||||
tokenizer_mode='auto', tensor_parallel_size=num_gpus, #
|
||||
model=llm.bentomodel.path, tokenizer=llm.bentomodel.path, #
|
||||
trust_remote_code=llm.trust_remote_code, dtype=llm._torch_dtype, #
|
||||
max_model_len=llm._max_model_len, gpu_memory_utilization=llm._gpu_memory_utilization, #
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False, #
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus, #
|
||||
model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path, #
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
dtype=llm._torch_dtype, #
|
||||
max_model_len=llm._max_model_len,
|
||||
gpu_memory_utilization=llm._gpu_memory_utilization, #
|
||||
quantization=llm.quantise if llm.quantise and llm.quantise in {'awq', 'squeezellm'} else None,
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
config, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
|
||||
_, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
|
||||
async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
|
||||
yield GenerationOutput.from_vllm(request_output).model_dump_json()
|
||||
|
||||
@@ -127,6 +156,7 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
class PyTorchRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm):
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
self.is_encoder_decoder = llm.model.config.is_encoder_decoder
|
||||
@@ -134,10 +164,13 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
self.device = llm.model.device
|
||||
else:
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
from ._generation import get_context_length, prepare_logits_processor
|
||||
if adapter_name is not None: self.model.set_adapter(adapter_name)
|
||||
|
||||
if adapter_name is not None:
|
||||
self.model.set_adapter(adapter_name)
|
||||
|
||||
max_new_tokens = attrs.pop('max_new_tokens', 256)
|
||||
context_length = attrs.pop('context_length', None)
|
||||
@@ -165,9 +198,7 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
if config['logprobs']: # FIXME: logprobs is not supported
|
||||
raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.')
|
||||
encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
|
||||
start_ids = torch.as_tensor(
|
||||
[[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device
|
||||
)
|
||||
start_ids = torch.as_tensor([[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device)
|
||||
else:
|
||||
start_ids = torch.as_tensor([prompt_token_ids], device=self.device)
|
||||
|
||||
@@ -195,9 +226,7 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
)
|
||||
logits = self.model.lm_head(out[0])
|
||||
else:
|
||||
out = self.model(
|
||||
input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True
|
||||
)
|
||||
out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
if logits_processor:
|
||||
@@ -241,12 +270,7 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
|
||||
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
|
||||
# XXX: Move this to API server
|
||||
text = self.tokenizer.decode(
|
||||
tmp_output_ids,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
|
||||
|
||||
if len(stop) > 0:
|
||||
for it in stop:
|
||||
@@ -255,7 +279,8 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
text, stopped = text[:pos], True
|
||||
break
|
||||
|
||||
if config['logprobs']: sample_logprobs.append({token: token_logprobs})
|
||||
if config['logprobs']:
|
||||
sample_logprobs.append({token: token_logprobs})
|
||||
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
@@ -296,7 +321,7 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
|
||||
request_id=request_id,
|
||||
).model_dump_json()
|
||||
).model_dump_json()
|
||||
|
||||
# Clean
|
||||
del past_key_values, out
|
||||
|
||||
@@ -1,19 +1,4 @@
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
)
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Protocol, Tuple, Type, TypeVar, Union, final
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -89,11 +74,7 @@ class Runner(Protocol[Mo, To]):
|
||||
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
|
||||
@staticmethod
|
||||
def async_stream(
|
||||
prompt_token_ids: List[int],
|
||||
request_id: str,
|
||||
stop: Optional[Union[Iterable[str], str]] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
prompt_token_ids: List[int], request_id: str, stop: Optional[Union[Iterable[str], str]] = ..., adapter_name: Optional[str] = ..., **attrs: Any
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -6,38 +6,48 @@ from bentoml.io import JSON, Text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=svars.model_id, model_tag=svars.model_tag, adapter_map=svars.adapter_map, #
|
||||
serialisation=svars.serialization, trust_remote_code=svars.trust_remote_code, #
|
||||
max_model_len=svars.max_model_len, gpu_memory_utilization=svars.gpu_memory_utilization, #
|
||||
model_id=svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
adapter_map=svars.adapter_map, #
|
||||
serialisation=svars.serialization,
|
||||
trust_remote_code=svars.trust_remote_code, #
|
||||
max_model_len=svars.max_model_len,
|
||||
gpu_memory_utilization=svars.gpu_memory_utilization, #
|
||||
)
|
||||
svc = bentoml.Service(name=f"llm-{llm.config['start_name']}-service", runners=[llm.runner])
|
||||
llm_model_class = openllm.GenerationInput.from_llm_config(llm.config)
|
||||
|
||||
@svc.api(
|
||||
route='/v1/generate',
|
||||
input=JSON.from_sample(llm_model_class.examples()),
|
||||
output=JSON.from_sample(openllm.GenerationOutput.examples()),
|
||||
)
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> dict[str, t.Any]: return (await llm.generate(**llm_model_class(**input_dict).model_dump())).model_dump()
|
||||
|
||||
@svc.api(
|
||||
route='/v1/generate_stream',
|
||||
input=JSON.from_sample(llm_model_class.examples()),
|
||||
output=Text(content_type='text/event-stream'),
|
||||
)
|
||||
@svc.api(route='/v1/generate', input=JSON.from_sample(llm_model_class.examples()), output=JSON.from_sample(openllm.GenerationOutput.examples()))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> dict[str, t.Any]:
|
||||
return (await llm.generate(**llm_model_class(**input_dict).model_dump())).model_dump()
|
||||
|
||||
|
||||
@svc.api(route='/v1/generate_stream', input=JSON.from_sample(llm_model_class.examples()), output=Text(content_type='text/event-stream'))
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()): yield f'data: {it.model_dump_json()}\n\n'
|
||||
async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()):
|
||||
yield f'data: {it.model_dump_json()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
|
||||
_Metadata = openllm.MetadataOutput(
|
||||
timeout=llm.config['timeout'], model_name=llm.config['model_name'], #
|
||||
backend=llm.__llm_backend__, model_id=llm.model_id, configuration=llm.config.model_dump_json().decode(), #
|
||||
timeout=llm.config['timeout'],
|
||||
model_name=llm.config['model_name'], #
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm.config.model_dump_json().decode(), #
|
||||
)
|
||||
|
||||
@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump()))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput: return _Metadata
|
||||
|
||||
class MessagesConverterInput(t.TypedDict): add_generation_prompt: bool; messages: t.List[t.Dict[str, t.Any]]
|
||||
@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump()))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return _Metadata
|
||||
|
||||
|
||||
class MessagesConverterInput(t.TypedDict):
|
||||
add_generation_prompt: bool
|
||||
messages: t.List[t.Dict[str, t.Any]]
|
||||
|
||||
|
||||
@svc.api(
|
||||
route='/v1/helpers/messages',
|
||||
@@ -46,7 +56,8 @@ class MessagesConverterInput(t.TypedDict): add_generation_prompt: bool; messages
|
||||
add_generation_prompt=False,
|
||||
messages=[
|
||||
MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
|
||||
MessageParam(role='user', content='Hi there!'), MessageParam(role='assistant', content='Yes?'), #
|
||||
MessageParam(role='user', content='Hi there!'),
|
||||
MessageParam(role='assistant', content='Yes?'), #
|
||||
],
|
||||
)
|
||||
),
|
||||
@@ -56,4 +67,5 @@ def helpers_messages_v1(message: MessagesConverterInput) -> str:
|
||||
add_generation_prompt, messages = message['add_generation_prompt'], message['messages']
|
||||
return llm.tokenizer.apply_chat_template(messages, add_generation_prompt=add_generation_prompt, tokenize=False)
|
||||
|
||||
openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
|
||||
openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
|
||||
@@ -1,3 +1,13 @@
|
||||
import os, orjson, openllm_core.utils as coreutils
|
||||
model_id, model_tag, adapter_map, serialization, trust_remote_code = os.environ['OPENLLM_MODEL_ID'], None, orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))), os.getenv('OPENLLM_SERIALIZATION', default='safetensors'), coreutils.check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
max_model_len, gpu_memory_utilization = orjson.loads(os.getenv('MAX_MODEL_LEN', orjson.dumps(None).decode())), orjson.loads(os.getenv('GPU_MEMORY_UTILIZATION', orjson.dumps(0.9).decode()))
|
||||
|
||||
model_id, model_tag, adapter_map, serialization, trust_remote_code = (
|
||||
os.environ['OPENLLM_MODEL_ID'],
|
||||
None,
|
||||
orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))),
|
||||
os.getenv('OPENLLM_SERIALIZATION', default='safetensors'),
|
||||
coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
|
||||
)
|
||||
max_model_len, gpu_memory_utilization = (
|
||||
orjson.loads(os.getenv('MAX_MODEL_LEN', orjson.dumps(None).decode())),
|
||||
orjson.loads(os.getenv('GPU_MEMORY_UTILIZATION', orjson.dumps(0.9).decode())),
|
||||
)
|
||||
|
||||
@@ -7,30 +7,42 @@ from bentoml._internal.runner.strategy import THREAD_ENVS
|
||||
__all__ = ['CascadingResourceStrategy', 'get_resource']
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strtoul(s: str) -> int:
|
||||
# Return -1 or positive integer sequence string starts with.
|
||||
if not s: return -1
|
||||
if not s:
|
||||
return -1
|
||||
idx = 0
|
||||
for idx, c in enumerate(s):
|
||||
if not (c.isdigit() or (idx == 0 and c in '+-')): break
|
||||
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
|
||||
if not (c.isdigit() or (idx == 0 and c in '+-')):
|
||||
break
|
||||
if idx + 1 == len(s):
|
||||
idx += 1
|
||||
# NOTE: idx will be set via enumerate
|
||||
return int(s[:idx]) if idx > 0 else -1
|
||||
|
||||
|
||||
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
|
||||
rcs = []
|
||||
for elem in lst.split(','):
|
||||
# Repeated id results in empty set
|
||||
if elem in rcs: return []
|
||||
if elem in rcs:
|
||||
return []
|
||||
# Anything other but prefix is ignored
|
||||
if not elem.startswith(prefix): break
|
||||
if not elem.startswith(prefix):
|
||||
break
|
||||
rcs.append(elem)
|
||||
return rcs
|
||||
|
||||
|
||||
def _parse_cuda_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
if respect_env:
|
||||
spec = os.environ.get('CUDA_VISIBLE_DEVICES', default_var)
|
||||
if not spec: return None
|
||||
if not spec:
|
||||
return None
|
||||
else:
|
||||
if default_var is None: raise ValueError('spec is required to be not None when parsing spec.')
|
||||
if default_var is None:
|
||||
raise ValueError('spec is required to be not None when parsing spec.')
|
||||
spec = default_var
|
||||
|
||||
if spec.startswith('GPU-'):
|
||||
@@ -44,48 +56,59 @@ def _parse_cuda_visible_devices(default_var: str | None = None, respect_env: boo
|
||||
for el in spec.split(','):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
if x in rc:
|
||||
return []
|
||||
# Negative value aborts the sequence
|
||||
if x < 0: break
|
||||
if x < 0:
|
||||
break
|
||||
rc.append(x)
|
||||
return [str(i) for i in rc]
|
||||
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL, byref, c_int, c_void_p, create_string_buffer
|
||||
|
||||
try:
|
||||
nvml_h = CDLL('libnvidia-ml.so.1')
|
||||
except Exception:
|
||||
warnings.warn('Failed to find nvidia binding', stacklevel=3); return None
|
||||
warnings.warn('Failed to find nvidia binding', stacklevel=3)
|
||||
return None
|
||||
|
||||
rc = nvml_h.nvmlInit()
|
||||
if rc != 0:
|
||||
warnings.warn("Can't initialize NVML", stacklevel=3); return None
|
||||
warnings.warn("Can't initialize NVML", stacklevel=3)
|
||||
return None
|
||||
dev_count = c_int(-1)
|
||||
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
||||
if rc != 0:
|
||||
warnings.warn('Failed to get available device from system.', stacklevel=3); return None
|
||||
warnings.warn('Failed to get available device from system.', stacklevel=3)
|
||||
return None
|
||||
uuids = []
|
||||
for idx in range(dev_count.value):
|
||||
dev_id = c_void_p()
|
||||
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
||||
if rc != 0:
|
||||
warnings.warn(f'Failed to get device handle for {idx}', stacklevel=3); return None
|
||||
warnings.warn(f'Failed to get device handle for {idx}', stacklevel=3)
|
||||
return None
|
||||
buf_len = 96
|
||||
buf = create_string_buffer(buf_len)
|
||||
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
||||
if rc != 0:
|
||||
warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=3); return None
|
||||
warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=3)
|
||||
return None
|
||||
uuids.append(buf.raw.decode('ascii').strip('\0'))
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
|
||||
class _ResourceMixin:
|
||||
@staticmethod
|
||||
def from_system(cls) -> list[str]:
|
||||
visible_devices = _parse_cuda_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
if not psutil.LINUX: return []
|
||||
if not psutil.LINUX:
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
@@ -99,7 +122,8 @@ class _ResourceMixin:
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
|
||||
return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
@@ -116,20 +140,26 @@ class _ResourceMixin:
|
||||
except (ImportError, RuntimeError, AttributeError):
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
@staticmethod
|
||||
def from_spec(cls, spec) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError('Spec cannot be < -1.')
|
||||
if spec in (-1, 0):
|
||||
return []
|
||||
if spec < -1:
|
||||
raise ValueError('Spec cannot be < -1.')
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
if not spec:
|
||||
return []
|
||||
if spec.isdigit():
|
||||
spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_cuda_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
|
||||
@staticmethod
|
||||
def validate(cls, val: list[t.Any]) -> None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
@@ -139,83 +169,102 @@ class _ResourceMixin:
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith(('GPU-', 'MIG-')):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids: raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
if uuids is None:
|
||||
raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids:
|
||||
raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f'Failed to get device {el}')
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise ValueError(f'Failed to get device {el}')
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[bentoml.Resource[t.List[str]]]:
|
||||
return types.new_class(
|
||||
name,
|
||||
(bentoml.Resource[t.List[str]], coreutils.ReprMixin),
|
||||
{'resource_id': resource_kind},
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_ResourceMixin.from_spec), 'from_system': classmethod(_ResourceMixin.from_system), #
|
||||
'validate': classmethod(_ResourceMixin.validate), '__repr_keys__': property(lambda _: {'resource_id'}), #
|
||||
'__doc__': inspect.cleandoc(docstring), '__module__': 'openllm._strategies', #
|
||||
}
|
||||
),
|
||||
lambda ns: ns.update({
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_ResourceMixin.from_spec),
|
||||
'from_system': classmethod(_ResourceMixin.from_system), #
|
||||
'validate': classmethod(_ResourceMixin.validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}), #
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies', #
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
'NvidiaGpuResource',
|
||||
'nvidia.com/gpu',
|
||||
'''NVIDIA GPU resource.
|
||||
"""NVIDIA GPU resource.
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.''',
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
'AmdGpuResource',
|
||||
'amd.com/gpu',
|
||||
'''AMD GPU resource.
|
||||
"""AMD GPU resource.
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.''',
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",
|
||||
)
|
||||
|
||||
|
||||
class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin):
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class, resource_request, workers_per_resource):
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
if resource_request is None:
|
||||
resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = 'nvidia.com/gpu'
|
||||
nvidia_req = get_resource(resource_request, kind)
|
||||
if nvidia_req is not None: return 1
|
||||
if nvidia_req is not None:
|
||||
return 1
|
||||
# use AMD
|
||||
kind = 'amd.com/gpu'
|
||||
amd_req = get_resource(resource_request, kind, validate=False)
|
||||
if amd_req is not None: return 1
|
||||
if amd_req is not None:
|
||||
return 1
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, 'cpu')
|
||||
if cpus is not None and cpus > 0:
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError('Fractional CPU multi threading support is not yet supported.')
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0:
|
||||
raise ValueError('Fractional CPU multi threading support is not yet supported.')
|
||||
return int(workers_per_resource)
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.')
|
||||
raise ValueError(
|
||||
f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class, resource_request, workers_per_resource, worker_index):
|
||||
cuda_env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
|
||||
disabled = cuda_env in ('', '-1')
|
||||
environ = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
if resource_request is None:
|
||||
resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = 'nvidia.com/gpu'
|
||||
typ = get_resource(resource_request, kind)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cuda_env; return environ
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cuda_env
|
||||
return environ
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
return environ
|
||||
# use AMD
|
||||
@@ -223,7 +272,8 @@ class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin):
|
||||
typ = get_resource(resource_request, kind, validate=False)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cuda_env; return environ
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cuda_env
|
||||
return environ
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
return environ
|
||||
# use CPU
|
||||
@@ -232,17 +282,21 @@ class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin):
|
||||
environ['CUDA_VISIBLE_DEVICES'] = '-1' # disable gpu
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
thread_count = math.ceil(cpus)
|
||||
for thread_env in THREAD_ENVS: environ[thread_env] = os.environ.get(thread_env, str(thread_count))
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.environ.get(thread_env, str(thread_count))
|
||||
return environ
|
||||
for thread_env in THREAD_ENVS: environ[thread_env] = os.environ.get(thread_env, '1')
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.environ.get(thread_env, '1')
|
||||
return environ
|
||||
return environ
|
||||
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(workers_per_resource, gpus, worker_index):
|
||||
# Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.
|
||||
if isinstance(workers_per_resource, float):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1: raise ValueError('workers_per_resource > 1 is not supported.')
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError('workers_per_resource > 1 is not supported.')
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
|
||||
@@ -13,16 +13,11 @@ class CascadingResourceStrategy:
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_count(
|
||||
cls,
|
||||
runnable_class: Type[bentoml.Runnable],
|
||||
resource_request: Optional[Dict[str, Any]],
|
||||
workers_per_resource: float,
|
||||
) -> int:
|
||||
'''Return the number of workers to be used for the given runnable class.
|
||||
def get_worker_count(cls, runnable_class: Type[bentoml.Runnable], resource_request: Optional[Dict[str, Any]], workers_per_resource: float) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
'''
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_env(
|
||||
cls,
|
||||
@@ -31,16 +26,14 @@ class CascadingResourceStrategy:
|
||||
workers_per_resource: Union[int, float],
|
||||
worker_index: int,
|
||||
) -> Dict[str, Any]:
|
||||
'''Get worker env for this given worker_index.
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
'''
|
||||
"""
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(
|
||||
workers_per_resource: Union[float, int], gpus: List[str], worker_index: int
|
||||
) -> str:
|
||||
'''Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.'''
|
||||
def transpile_workers_to_cuda_envvar(workers_per_resource: Union[float, int], gpus: List[str], worker_index: int) -> str:
|
||||
"""Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string."""
|
||||
|
||||
@@ -4,11 +4,13 @@ from openllm_core._typing_compat import LiteralVersionStrategy
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils.lazy import VersionInfo, LazyModule
|
||||
|
||||
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
git_hash: str = attr.field()
|
||||
version: VersionInfo = attr.field(converter=lambda s: VersionInfo.from_version_string(s))
|
||||
strategy: LiteralVersionStrategy = attr.field()
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_strategy(cls, strategy_or_version: LiteralVersionStrategy | None = None) -> RefResolver:
|
||||
@@ -16,6 +18,7 @@ class RefResolver:
|
||||
if strategy_or_version is None or strategy_or_version == 'release':
|
||||
try:
|
||||
from ghapi.all import GhApi
|
||||
|
||||
ghapi = GhApi(owner='bentoml', repo='openllm', authenticate=False)
|
||||
meta = ghapi.repos.get_latest_release()
|
||||
git_hash = ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha']
|
||||
@@ -26,11 +29,16 @@ class RefResolver:
|
||||
return cls('latest', '0.0.0', 'latest')
|
||||
else:
|
||||
raise ValueError(f'Unknown strategy: {strategy_or_version}')
|
||||
|
||||
@property
|
||||
def tag(self) -> str: return 'latest' if self.strategy in {'latest', 'nightly'} else repr(self.version)
|
||||
def tag(self) -> str:
|
||||
return 'latest' if self.strategy in {'latest', 'nightly'} else repr(self.version)
|
||||
|
||||
|
||||
__lazy = LazyModule(
|
||||
__name__, os.path.abspath('__file__'), #
|
||||
__name__,
|
||||
os.path.abspath('__file__'), #
|
||||
{'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options']},
|
||||
extra_objects={'RefResolver': RefResolver}
|
||||
extra_objects={'RefResolver': RefResolver},
|
||||
)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
|
||||
@@ -12,14 +12,18 @@ OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
_service_file = pathlib.Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
_SERVICE_VARS = '''import orjson;model_id,model_tag,adapter_map,serialization,trust_remote_code,max_model_len,gpu_memory_utilization='{__model_id__}','{__model_tag__}',orjson.loads("""{__model_adapter_map__}"""),'{__model_serialization__}',{__model_trust_remote_code__},{__max_model_len__},{__gpu_memory_utilization__}'''
|
||||
|
||||
|
||||
def build_editable(path, package='openllm'):
|
||||
if not check_bool_env(OPENLLM_DEV_BUILD, default=False): return None
|
||||
if not check_bool_env(OPENLLM_DEV_BUILD, default=False):
|
||||
return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
# TODO: Upgrade to 1.0.3
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
|
||||
module_location = pkg.source_locations(package)
|
||||
if not module_location: raise RuntimeError('Could not find the source location of OpenLLM.')
|
||||
if not module_location:
|
||||
raise RuntimeError('Could not find the source location of OpenLLM.')
|
||||
pyproject_path = pathlib.Path(module_location).parent.parent / 'pyproject.toml'
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
with IsolatedEnvBuilder() as env:
|
||||
@@ -29,56 +33,79 @@ def build_editable(path, package='openllm'):
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Please install OpenLLM from PyPI or built it from Git source.')
|
||||
|
||||
|
||||
def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=None):
|
||||
from . import RefResolver
|
||||
|
||||
packages = ['scipy', 'bentoml[tracing]>=1.1.10', f'openllm[vllm]>={RefResolver.from_strategy("release").version}'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
if extra_dependencies is not None: packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
if llm.config['requirements'] is not None: packages.extend(llm.config['requirements'])
|
||||
if adapter_map is not None:
|
||||
packages += ['openllm[fine-tune]']
|
||||
if extra_dependencies is not None:
|
||||
packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
if llm.config['requirements'] is not None:
|
||||
packages.extend(llm.config['requirements'])
|
||||
built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
|
||||
return PythonOptions(packages=packages, wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None, lock_packages=False)
|
||||
return PythonOptions(
|
||||
packages=packages,
|
||||
wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None,
|
||||
lock_packages=False,
|
||||
)
|
||||
|
||||
|
||||
def construct_docker_options(llm, _, quantize, adapter_map, dockerfile_template, serialisation):
|
||||
from openllm_cli.entrypoint import process_environ
|
||||
environ = process_environ(
|
||||
llm.config, llm.config['timeout'],
|
||||
1.0, None, True,
|
||||
llm.model_id, None,
|
||||
llm._serialisation, llm,
|
||||
use_current_env=False,
|
||||
)
|
||||
|
||||
environ = process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm, use_current_env=False)
|
||||
# XXX: We need to quote this so that the envvar in container recognize as valid json
|
||||
environ['OPENLLM_CONFIG'] = f"'{environ['OPENLLM_CONFIG']}'"
|
||||
environ.pop('BENTOML_HOME', None) # NOTE: irrelevant in container
|
||||
environ['NVIDIA_DRIVER_CAPABILITIES'] = 'compute,utility'
|
||||
return DockerOptions(cuda_version='12.1', python_version='3.11', env=environ, dockerfile_template=dockerfile_template)
|
||||
|
||||
|
||||
@inject
|
||||
def create_bento(
|
||||
bento_tag, llm_fs, llm, #
|
||||
quantize, dockerfile_template, #
|
||||
adapter_map=None, extra_dependencies=None, serialisation=None, #
|
||||
_bento_store=Provide[BentoMLContainer.bento_store], _model_store=Provide[BentoMLContainer.model_store],
|
||||
bento_tag,
|
||||
llm_fs,
|
||||
llm, #
|
||||
quantize,
|
||||
dockerfile_template, #
|
||||
adapter_map=None,
|
||||
extra_dependencies=None,
|
||||
serialisation=None, #
|
||||
_bento_store=Provide[BentoMLContainer.bento_store],
|
||||
_model_store=Provide[BentoMLContainer.model_store],
|
||||
):
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation'])
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update(
|
||||
{
|
||||
'_type': llm.llm_type, '_framework': llm.__llm_backend__,
|
||||
'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle',
|
||||
**{f'{package.replace("-","_")}_version': importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
|
||||
}
|
||||
)
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
labels.update({
|
||||
'_type': llm.llm_type,
|
||||
'_framework': llm.__llm_backend__,
|
||||
'start_name': llm.config['start_name'],
|
||||
'base_name_or_path': llm.model_id,
|
||||
'bundler': 'openllm.bundle',
|
||||
**{f'{package.replace("-","_")}_version': importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
|
||||
})
|
||||
if adapter_map:
|
||||
labels.update(adapter_map)
|
||||
|
||||
logger.debug("Building Bento '%s' with model backend '%s'", bento_tag, llm.__llm_backend__)
|
||||
logger.debug('Generating service vars %s (dir=%s)', llm.model_id, llm_fs.getsyspath('/'))
|
||||
script = f"# fmt: off\n# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n" + _SERVICE_VARS.format(
|
||||
__model_id__=llm.model_id, __model_tag__=str(llm.tag), #
|
||||
__model_adapter_map__=orjson.dumps(adapter_map).decode(), __model_serialization__=llm.config['serialisation'], #
|
||||
__model_trust_remote_code__=str(llm.trust_remote_code), __max_model_len__ = llm._max_model_len, __gpu_memory_utilization__=llm._gpu_memory_utilization, #
|
||||
__model_id__=llm.model_id,
|
||||
__model_tag__=str(llm.tag), #
|
||||
__model_adapter_map__=orjson.dumps(adapter_map).decode(),
|
||||
__model_serialization__=llm.config['serialisation'], #
|
||||
__model_trust_remote_code__=str(llm.trust_remote_code),
|
||||
__max_model_len__=llm._max_model_len,
|
||||
__gpu_memory_utilization__=llm._gpu_memory_utilization, #
|
||||
)
|
||||
if SHOW_CODEGEN: logger.info('Generated _service_vars.py:\n%s', script)
|
||||
if SHOW_CODEGEN:
|
||||
logger.info('Generated _service_vars.py:\n%s', script)
|
||||
llm_fs.writetext('_service_vars.py', script)
|
||||
with open(_service_file.__fspath__(), 'r') as f: service_src = f.read()
|
||||
with open(_service_file.__fspath__(), 'r') as f:
|
||||
service_src = f.read()
|
||||
llm_fs.writetext(llm.config['service_name'], service_src)
|
||||
return bentoml.Bento.create(
|
||||
version=bento_tag.version,
|
||||
|
||||
@@ -7,21 +7,13 @@ from bentoml import Bento, Tag
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.bento.build_config import DockerOptions, PythonOptions
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
from openllm_core._typing_compat import (
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
M,
|
||||
T,
|
||||
)
|
||||
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, M, T
|
||||
|
||||
from .._llm import LLM
|
||||
|
||||
def build_editable(path: str, package: LiteralString) -> Optional[str]: ...
|
||||
def construct_python_options(
|
||||
llm: LLM[M, T],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: Optional[Tuple[str, ...]] = ...,
|
||||
adapter_map: Optional[Dict[str, str]] = ...,
|
||||
llm: LLM[M, T], llm_fs: FS, extra_dependencies: Optional[Tuple[str, ...]] = ..., adapter_map: Optional[Dict[str, str]] = ...
|
||||
) -> PythonOptions: ...
|
||||
def construct_docker_options(
|
||||
llm: LLM[M, T],
|
||||
|
||||
@@ -1,2 +1,10 @@
|
||||
def __dir__(): import openllm_client as _client; return sorted(dir(_client))
|
||||
def __getattr__(it): import openllm_client as _client; return getattr(_client, it)
|
||||
def __dir__():
|
||||
import openllm_client as _client
|
||||
|
||||
return sorted(dir(_client))
|
||||
|
||||
|
||||
def __getattr__(it):
|
||||
import openllm_client as _client
|
||||
|
||||
return getattr(_client, it)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
'''OpenLLM Python client.
|
||||
"""OpenLLM Python client.
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient
|
||||
|
||||
@@ -2,10 +2,14 @@ import importlib
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure = {'openai': [], 'hf': [], 'cohere': []}
|
||||
|
||||
|
||||
def mount_entrypoints(svc, llm):
|
||||
for module_name in _import_structure:
|
||||
module = importlib.import_module(f'.{module_name}', __name__)
|
||||
svc = module.mount_to_svc(svc, llm)
|
||||
return svc
|
||||
|
||||
|
||||
__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints})
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
'''Entrypoint for all third-party apps.
|
||||
"""Entrypoint for all third-party apps.
|
||||
|
||||
Currently support OpenAI, Cohere compatible API.
|
||||
|
||||
Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
'''
|
||||
"""
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
@@ -11,7 +11,7 @@ from openllm_core.utils import first_not_none
|
||||
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODELS_SCHEMA = '''\
|
||||
LIST_MODELS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -41,8 +41,8 @@ responses:
|
||||
owned_by: 'na'
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelList'
|
||||
'''
|
||||
CHAT_COMPLETIONS_SCHEMA = '''\
|
||||
"""
|
||||
CHAT_COMPLETIONS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -181,8 +181,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
COMPLETIONS_SCHEMA = '''\
|
||||
"""
|
||||
COMPLETIONS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -334,8 +334,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
HF_AGENT_SCHEMA = '''\
|
||||
"""
|
||||
HF_AGENT_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -379,8 +379,8 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
'''
|
||||
HF_ADAPTERS_SCHEMA = '''\
|
||||
"""
|
||||
HF_ADAPTERS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -410,8 +410,8 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
'''
|
||||
COHERE_GENERATE_SCHEMA = '''\
|
||||
"""
|
||||
COHERE_GENERATE_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -455,8 +455,8 @@ requestBody:
|
||||
stop_sequences:
|
||||
- "\\n"
|
||||
- "<|endoftext|>"
|
||||
'''
|
||||
COHERE_CHAT_SCHEMA = '''\
|
||||
"""
|
||||
COHERE_CHAT_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -469,7 +469,7 @@ tags:
|
||||
- Cohere
|
||||
x-bentoml-name: cohere_chat
|
||||
summary: Creates a model response for the given chat conversation.
|
||||
'''
|
||||
"""
|
||||
|
||||
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
|
||||
|
||||
@@ -504,11 +504,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
elif (
|
||||
inspect.isfunction(route.endpoint)
|
||||
or inspect.ismethod(route.endpoint)
|
||||
or isinstance(route.endpoint, functools.partial)
|
||||
):
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
|
||||
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ['GET']:
|
||||
@@ -555,22 +551,20 @@ def get_generator(title, components=None, tags=None, inject=True):
|
||||
|
||||
def component_schema_generator(attr_cls, description=None):
|
||||
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)):
|
||||
attr_type = field.type
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
|
||||
# Map Python types to OpenAPI schema types
|
||||
if attr_type == str:
|
||||
if isinstance(attr_type, str):
|
||||
schema_type = 'string'
|
||||
elif attr_type == int:
|
||||
elif isinstance(attr_type, int):
|
||||
schema_type = 'integer'
|
||||
elif attr_type == float:
|
||||
elif isinstance(attr_type, float):
|
||||
schema_type = 'number'
|
||||
elif attr_type == bool:
|
||||
elif isinstance(attr_type, bool):
|
||||
schema_type = 'boolean'
|
||||
elif origin_type is list or origin_type is tuple:
|
||||
schema_type = 'array'
|
||||
@@ -599,10 +593,7 @@ def component_schema_generator(attr_cls, description=None):
|
||||
|
||||
|
||||
_SimpleSchema = types.new_class(
|
||||
'_SimpleSchema',
|
||||
(object,),
|
||||
{},
|
||||
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
|
||||
'_SimpleSchema', (object,), {}, lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it})
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,13 +17,8 @@ class OpenLLMSchemaGenerator:
|
||||
|
||||
def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
|
||||
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
|
||||
def append_schemas(
|
||||
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
|
||||
) -> Service: ...
|
||||
def append_schemas(svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...) -> Service: ...
|
||||
def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ...
|
||||
def get_generator(
|
||||
title: str,
|
||||
components: Optional[List[Type[AttrsInstance]]] = ...,
|
||||
tags: Optional[List[Dict[str, Any]]] = ...,
|
||||
inject: bool = ...,
|
||||
title: str, components: Optional[List[Type[AttrsInstance]]] = ..., tags: Optional[List[Dict[str, Any]]] = ..., inject: bool = ...
|
||||
) -> OpenLLMSchemaGenerator: ...
|
||||
|
||||
@@ -48,17 +48,19 @@ schemas = get_generator(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def jsonify_attr(obj): return json.dumps(converter.unstructure(obj))
|
||||
def jsonify_attr(obj):
|
||||
return json.dumps(converter.unstructure(obj))
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value)
|
||||
|
||||
|
||||
async def check_model(request, model):
|
||||
if request.model is None or request.model == model: return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see current running models.",
|
||||
)
|
||||
if request.model is None or request.model == model:
|
||||
return None
|
||||
return error_response(HTTPStatus.NOT_FOUND, f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see current running models.")
|
||||
|
||||
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
@@ -74,6 +76,7 @@ def mount_to_svc(svc, llm):
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=DEBUG)
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def cohere_generate(req, llm):
|
||||
json_str = await req.body()
|
||||
@@ -130,18 +133,14 @@ async def cohere_generate(req, llm):
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs]
|
||||
)
|
||||
return JSONResponse(
|
||||
converter.unstructure(
|
||||
Generations(
|
||||
id=request_id,
|
||||
generations=[
|
||||
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
|
||||
for output in final_result.outputs
|
||||
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason) for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
),
|
||||
@@ -165,6 +164,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str
|
||||
messages.append({'role': 'user', 'content': request.message})
|
||||
return messages
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def cohere_chat(req, llm):
|
||||
json_str = await req.body()
|
||||
@@ -247,9 +247,7 @@ async def cohere_chat(req, llm):
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]
|
||||
)
|
||||
final_result = final_result.with_options(outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)])
|
||||
num_prompt_tokens, num_response_tokens = len(final_result.prompt_token_ids), len(token_ids)
|
||||
return JSONResponse(
|
||||
converter.unstructure(
|
||||
|
||||
@@ -14,8 +14,6 @@ from ..protocol.cohere import CohereChatRequest, CohereGenerateRequest
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(
|
||||
request: Union[CohereGenerateRequest, CohereChatRequest], model: str
|
||||
) -> Optional[JSONResponse]: ...
|
||||
async def check_model(request: Union[CohereGenerateRequest, CohereChatRequest], model: str) -> Optional[JSONResponse]: ...
|
||||
async def cohere_generate(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def cohere_chat(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
|
||||
@@ -21,6 +21,7 @@ schemas = get_generator(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
@@ -34,9 +35,11 @@ def mount_to_svc(svc, llm):
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value)
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def hf_agent(req, llm):
|
||||
json_str = await req.body()
|
||||
@@ -55,9 +58,11 @@ async def hf_agent(req, llm):
|
||||
logger.error('Error while generating: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
def hf_adapters(req, llm):
|
||||
if not llm.has_adapters: return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
if not llm.has_adapters:
|
||||
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
return JSONResponse(
|
||||
{
|
||||
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': adapter_tuple[0].peft_type.value}
|
||||
|
||||
@@ -55,7 +55,9 @@ schemas = get_generator(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def jsonify_attr(obj): return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
def jsonify_attr(obj):
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
@@ -63,8 +65,10 @@ def error_response(status_code, message):
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
async def check_model(request, model):
|
||||
if request.model == model: return None
|
||||
if request.model == model:
|
||||
return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
@@ -75,11 +79,13 @@ def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initi
|
||||
# Create OpenAI-style logprobs.
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs: logprobs.top_logprobs = []
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
token_logprob = None
|
||||
if step_top_logprobs is not None: token_logprob = step_top_logprobs[token_id]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id]
|
||||
token = llm.tokenizer.convert_ids_to_tokens(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
@@ -100,23 +106,20 @@ def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route(
|
||||
'/models',
|
||||
functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm),
|
||||
methods=['GET']
|
||||
),
|
||||
Route(
|
||||
'/completions',
|
||||
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
|
||||
methods=['POST'],
|
||||
),
|
||||
Route('/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm), methods=['POST']),
|
||||
Route(
|
||||
'/chat/completions',
|
||||
functools.partial(apply_schema(chat_completions,
|
||||
__model_id__=llm.llm_type,
|
||||
__chat_template__=orjson.dumps(llm.config.chat_template).decode(),
|
||||
__chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
|
||||
__add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False)), llm=llm),
|
||||
functools.partial(
|
||||
apply_schema(
|
||||
chat_completions,
|
||||
__model_id__=llm.llm_type,
|
||||
__chat_template__=orjson.dumps(llm.config.chat_template).decode(),
|
||||
__chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
|
||||
__add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False),
|
||||
),
|
||||
llm=llm,
|
||||
),
|
||||
methods=['POST'],
|
||||
),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
@@ -128,7 +131,9 @@ def mount_to_svc(svc, llm):
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions
|
||||
def list_models(_, llm): return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
def list_models(_, llm):
|
||||
return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions
|
||||
@@ -138,26 +143,36 @@ async def chat_completions(req, llm):
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.monotonic())
|
||||
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
|
||||
prompt = llm.tokenizer.apply_chat_template(
|
||||
request.messages,
|
||||
tokenize=False,
|
||||
chat_template=request.chat_template if request.chat_template != 'None' else None,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
def get_role() -> str: return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant' # TODO: Support custom role here.
|
||||
def get_role() -> str:
|
||||
return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant' # TODO: Support custom role here.
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc(); logger.error('Error generating completion: %s', err)
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index, text, finish_reason=None, usage=None):
|
||||
@@ -167,25 +182,30 @@ async def chat_completions(req, llm):
|
||||
model=model_name,
|
||||
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)],
|
||||
)
|
||||
if usage is not None: response.usage = usage
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
return jsonify_attr(response)
|
||||
|
||||
async def completion_stream_generator():
|
||||
# first chunk with role
|
||||
role = get_role()
|
||||
for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
|
||||
for i in range(config['n']):
|
||||
yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
|
||||
|
||||
if request.echo:
|
||||
last_message, last_content = request.messages[-1], ''
|
||||
if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
|
||||
if last_message.get('content') and last_message.get('role') == role:
|
||||
last_content = last_message['content']
|
||||
if last_content:
|
||||
for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'
|
||||
for i in range(config['n']):
|
||||
yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'
|
||||
|
||||
previous_num_tokens = [0] * config['n']
|
||||
finish_reason_sent = [False] * config['n']
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
if finish_reason_sent[output.index]: continue
|
||||
if finish_reason_sent[output.index]:
|
||||
continue
|
||||
yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
|
||||
previous_num_tokens[output.index] += len(output.token_ids)
|
||||
if output.finish_reason is not None:
|
||||
@@ -197,35 +217,32 @@ async def chat_completions(req, llm):
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs]
|
||||
)
|
||||
|
||||
role = get_role()
|
||||
choices = [
|
||||
ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason)
|
||||
for output in final_result.outputs
|
||||
]
|
||||
if request.echo:
|
||||
last_message, last_content = request.messages[-1], ''
|
||||
if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
|
||||
if last_message.get('content') and last_message.get('role') == role:
|
||||
last_content = last_message['content']
|
||||
for choice in choices:
|
||||
full_message = last_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
@@ -236,7 +253,8 @@ async def chat_completions(req, llm):
|
||||
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc(); logger.error('Error generating completion: %s', err)
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
@@ -248,20 +266,26 @@ async def completions(req, llm):
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received legacy completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
# OpenAI API supports echoing the prompt when max_tokens is 0.
|
||||
echo_without_generation = request.echo and request.max_tokens == 0
|
||||
if echo_without_generation: request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back.
|
||||
if echo_without_generation:
|
||||
request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back.
|
||||
|
||||
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
if request.suffix is not None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
if not request.prompt:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
@@ -272,7 +296,8 @@ async def completions(req, llm):
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc(); logger.error('Error generating completion: %s', err)
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
# best_of != n then we don't stream
|
||||
@@ -286,7 +311,8 @@ async def completions(req, llm):
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
|
||||
)
|
||||
if usage: response.usage = usage
|
||||
if usage:
|
||||
response.usage = usage
|
||||
return jsonify_attr(response)
|
||||
|
||||
async def completion_stream_generator():
|
||||
@@ -301,7 +327,7 @@ async def completions(req, llm):
|
||||
logprobs = None
|
||||
top_logprobs = None
|
||||
if request.logprobs is not None:
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i] :]
|
||||
|
||||
if request.echo and not previous_echo[i]:
|
||||
if not echo_without_generation:
|
||||
@@ -316,7 +342,7 @@ async def completions(req, llm):
|
||||
top_logprobs = res.prompt_logprobs
|
||||
previous_echo[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], request.logprobs, len(previous_texts[i]), llm=llm)
|
||||
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i] :], request.logprobs, len(previous_texts[i]), llm=llm)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
previous_texts[i] += output.text
|
||||
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
|
||||
@@ -329,21 +355,21 @@ async def completions(req, llm):
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs]
|
||||
)
|
||||
|
||||
choices = []
|
||||
@@ -355,13 +381,15 @@ async def completions(req, llm):
|
||||
if request.logprobs is not None:
|
||||
if not echo_without_generation:
|
||||
token_ids, top_logprobs = output.token_ids, output.logprobs
|
||||
if request.echo: token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
|
||||
if request.echo:
|
||||
token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
|
||||
else:
|
||||
token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
|
||||
logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm)
|
||||
if not echo_without_generation:
|
||||
output_text = output.text
|
||||
if request.echo: output_text = prompt_text + output_text
|
||||
if request.echo:
|
||||
output_text = prompt_text + output_text
|
||||
else:
|
||||
output_text = prompt_text
|
||||
choice_data = CompletionResponseChoice(index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason)
|
||||
|
||||
@@ -14,12 +14,14 @@ from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest], model: str
|
||||
) -> Optional[JSONResponse]: ...
|
||||
async def check_model(request: Union[CompletionRequest, ChatCompletionRequest], model: str) -> Optional[JSONResponse]: ...
|
||||
def create_logprobs(
|
||||
token_ids: List[int], top_logprobs: List[Dict[int, float]], #
|
||||
num_output_top_logprobs: Optional[int] = ..., initial_text_offset: int = ..., *, llm: LLM[M, T]
|
||||
token_ids: List[int],
|
||||
top_logprobs: List[Dict[int, float]], #
|
||||
num_output_top_logprobs: Optional[int] = ...,
|
||||
initial_text_offset: int = ...,
|
||||
*,
|
||||
llm: LLM[M, T],
|
||||
) -> LogProbs: ...
|
||||
def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from openllm_core.exceptions import (
|
||||
Error as Error, FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError, #
|
||||
ForbiddenAttributeError as ForbiddenAttributeError, GpuNotAvailableError as GpuNotAvailableError, #
|
||||
OpenLLMException as OpenLLMException, ValidationError as ValidationError, #
|
||||
Error as Error,
|
||||
FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError, #
|
||||
ForbiddenAttributeError as ForbiddenAttributeError,
|
||||
GpuNotAvailableError as GpuNotAvailableError, #
|
||||
OpenLLMException as OpenLLMException,
|
||||
ValidationError as ValidationError, #
|
||||
MissingAnnotationAttributeError as MissingAnnotationAttributeError,
|
||||
MissingDependencyError as MissingDependencyError,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import openllm, transformers
|
||||
import openllm, transformers, typing as t
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM, config: transformers.PretrainedConfig, **attrs: t.Any): ...
|
||||
|
||||
@@ -5,6 +5,7 @@ import typing as t
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': [], 'cohere': [], 'hf': []}
|
||||
if t.TYPE_CHECKING: from . import cohere as cohere, hf as hf, openai as openai
|
||||
if t.TYPE_CHECKING:
|
||||
from . import cohere as cohere, hf as hf, openai as openai
|
||||
__lazy = LazyModule(__name__, os.path.abspath('__file__'), _import_structure)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
|
||||
@@ -17,10 +17,13 @@ class ErrorResponse:
|
||||
param: t.Optional[str] = None
|
||||
code: t.Optional[str] = None
|
||||
|
||||
|
||||
def _stop_converter(data: t.Union[str, t.List[str]]) -> t.List[str]:
|
||||
if not data: return None
|
||||
if not data:
|
||||
return None
|
||||
return [data] if isinstance(data, str) else data
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionRequest:
|
||||
prompt: str
|
||||
|
||||
@@ -3,10 +3,13 @@ import importlib, typing as t
|
||||
from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard, Concatenate
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
|
||||
if t.TYPE_CHECKING: from bentoml import Model; from .._llm import LLM
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml import Model
|
||||
from .._llm import LLM
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
|
||||
def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
|
||||
import cloudpickle, fs, transformers
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
@@ -38,16 +41,17 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], TypeGuard[M | T | Model]]:
|
||||
def caller(llm: LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> TypeGuard[M | T | Model]:
|
||||
'''Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.ctranslate' if 'llm.__llm_backend__="ctranslate"'
|
||||
'''
|
||||
"""
|
||||
if llm.__llm_backend__ == 'ggml':
|
||||
serde = 'ggml'
|
||||
elif llm.__llm_backend__ == 'ctranslate':
|
||||
@@ -57,12 +61,19 @@ def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], Ty
|
||||
else:
|
||||
raise OpenLLMException(f'Not supported backend {llm.__llm_backend__}')
|
||||
return getattr(importlib.import_module(f'.{serde}', 'openllm.serialisation'), fn)(llm, *args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
|
||||
_extras = ['get', 'import_model', 'load_model']
|
||||
_import_structure = {'ggml', 'transformers', 'ctranslate', 'constants'}
|
||||
__all__ = ['load_tokenizer', *_extras, *_import_structure]
|
||||
def __dir__() -> t.Sequence[str]: return sorted(__all__)
|
||||
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'load_tokenizer':
|
||||
return load_tokenizer
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
'''Serialisation utilities for OpenLLM.
|
||||
"""Serialisation utilities for OpenLLM.
|
||||
|
||||
Currently supports transformers for PyTorch, and vLLM.
|
||||
|
||||
Currently, GGML format is working in progress.
|
||||
'''
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from bentoml import Model
|
||||
from openllm import LLM
|
||||
@@ -11,11 +12,12 @@ from openllm_core._typing_compat import M, T
|
||||
from . import constants as constants, ggml as ggml, transformers as transformers
|
||||
|
||||
def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T:
|
||||
'''Load the tokenizer from BentoML store.
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
'''
|
||||
"""
|
||||
|
||||
def get(llm: LLM[M, T]) -> Model: ...
|
||||
def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
|
||||
def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
|
||||
|
||||
@@ -8,61 +8,82 @@ from openllm_core.utils import is_autogptq_available
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def get_hash(config) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
|
||||
if _commit_hash is None:
|
||||
raise ValueError(f'Cannot find commit hash in {config}')
|
||||
return _commit_hash
|
||||
|
||||
|
||||
def patch_correct_tag(llm, config, _revision=None) -> None:
|
||||
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
|
||||
if llm.revision is not None: return
|
||||
if llm.revision is not None:
|
||||
return
|
||||
if not llm.local:
|
||||
try:
|
||||
if _revision is None: _revision = get_hash(config)
|
||||
if _revision is None:
|
||||
_revision = get_hash(config)
|
||||
except ValueError:
|
||||
pass
|
||||
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
|
||||
if _revision is None and llm.tag.version is not None:
|
||||
_revision = llm.tag.version
|
||||
if llm.tag.version is None:
|
||||
_object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
_object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None:
|
||||
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
|
||||
|
||||
def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadata=None):
|
||||
if metadata is None: metadata = {}
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata.update({'safe_serialisation': safe_serialisation, '_framework': llm.__llm_backend__})
|
||||
if llm.quantise: metadata['_quantize'] = llm.quantise
|
||||
if llm.quantise:
|
||||
metadata['_quantize'] = llm.quantise
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
if not architectures:
|
||||
if trust_remote_code:
|
||||
auto_map = getattr(config, 'auto_map', {})
|
||||
if not auto_map: raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}')
|
||||
if not auto_map:
|
||||
raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}')
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if autoclass not in auto_map:
|
||||
raise RuntimeError(f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models.")
|
||||
raise RuntimeError(
|
||||
f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models."
|
||||
)
|
||||
architectures = [auto_map[autoclass]]
|
||||
else:
|
||||
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
|
||||
raise RuntimeError(
|
||||
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
|
||||
)
|
||||
metadata.update({'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision})
|
||||
return metadata
|
||||
|
||||
|
||||
def _create_signatures(llm, signatures=None):
|
||||
if signatures is None: signatures = {}
|
||||
if signatures is None:
|
||||
signatures = {}
|
||||
if llm.__llm_backend__ == 'pt':
|
||||
if llm.quantise == 'gptq':
|
||||
if not is_autogptq_available():
|
||||
raise OpenLLMException("Requires 'auto-gptq' and 'optimum'. Install it with 'pip install \"openllm[gptq]\"'")
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
signatures.update(
|
||||
{
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in (
|
||||
'__call__', 'forward', 'generate', #
|
||||
'contrastive_search', 'greedy_search', #
|
||||
'sample', 'beam_search', 'beam_sample', #
|
||||
'group_beam_search', 'constrained_beam_search', #
|
||||
)
|
||||
}
|
||||
)
|
||||
signatures.update({
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in (
|
||||
'__call__',
|
||||
'forward',
|
||||
'generate', #
|
||||
'contrastive_search',
|
||||
'greedy_search', #
|
||||
'sample',
|
||||
'beam_search',
|
||||
'beam_sample', #
|
||||
'group_beam_search',
|
||||
'constrained_beam_search', #
|
||||
)
|
||||
})
|
||||
elif llm.__llm_backend__ == 'ctranslate':
|
||||
if llm.config['model_type'] == 'seq2seq_lm':
|
||||
non_batch_keys = {'score_file', 'translate_file'}
|
||||
@@ -70,24 +91,37 @@ def _create_signatures(llm, signatures=None):
|
||||
else:
|
||||
non_batch_keys = set()
|
||||
batch_keys = {
|
||||
'async_generate_tokens', 'forward_batch', 'generate_batch', #
|
||||
'generate_iterable', 'generate_tokens', 'score_batch', 'score_iterable', #
|
||||
'async_generate_tokens',
|
||||
'forward_batch',
|
||||
'generate_batch', #
|
||||
'generate_iterable',
|
||||
'generate_tokens',
|
||||
'score_batch',
|
||||
'score_iterable', #
|
||||
}
|
||||
signatures.update({k: ModelSignature(batchable=False) for k in non_batch_keys})
|
||||
signatures.update({k: ModelSignature(batchable=True) for k in batch_keys})
|
||||
return signatures
|
||||
|
||||
|
||||
@inject
|
||||
@contextlib.contextmanager
|
||||
def save_model(
|
||||
llm, config, safe_serialisation, #
|
||||
trust_remote_code, module, external_modules, #
|
||||
_model_store=Provide[BentoMLContainer.model_store], _api_version='v2.1.0', #
|
||||
llm,
|
||||
config,
|
||||
safe_serialisation, #
|
||||
trust_remote_code,
|
||||
module,
|
||||
external_modules, #
|
||||
_model_store=Provide[BentoMLContainer.model_store],
|
||||
_api_version='v2.1.0', #
|
||||
):
|
||||
imported_modules = []
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag, module=f'openllm.serialisation.{module}', #
|
||||
api_version=_api_version, options=ModelOptions(), #
|
||||
llm.tag,
|
||||
module=f'openllm.serialisation.{module}', #
|
||||
api_version=_api_version,
|
||||
options=ModelOptions(), #
|
||||
context=openllm.utils.generate_context('openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
metadata=_create_metadata(llm, config, safe_serialisation, trust_remote_code),
|
||||
@@ -103,9 +137,7 @@ def save_model(
|
||||
bentomodel.flush()
|
||||
bentomodel.save(_model_store)
|
||||
openllm.utils.analytics.track(
|
||||
openllm.utils.analytics.ModelSaveEvent(
|
||||
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
|
||||
)
|
||||
openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024)
|
||||
)
|
||||
finally:
|
||||
bentomodel.exit_cloudpickle_context(imported_modules)
|
||||
|
||||
@@ -10,9 +10,7 @@ from openllm_core._typing_compat import M, T
|
||||
from .._llm import LLM
|
||||
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str: ...
|
||||
def patch_correct_tag(
|
||||
llm: LLM[M, T], config: transformers.PretrainedConfig, _revision: Optional[str] = ...
|
||||
) -> None: ...
|
||||
def patch_correct_tag(llm: LLM[M, T], config: transformers.PretrainedConfig, _revision: Optional[str] = ...) -> None: ...
|
||||
@contextmanager
|
||||
def save_model(
|
||||
llm: LLM[M, T],
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
HUB_ATTRS = [
|
||||
'cache_dir', 'code_revision', 'force_download', #
|
||||
'local_files_only', 'proxies', 'resume_download', #
|
||||
'revision', 'subfolder', 'use_auth_token', #
|
||||
'cache_dir',
|
||||
'code_revision',
|
||||
'force_download', #
|
||||
'local_files_only',
|
||||
'proxies',
|
||||
'resume_download', #
|
||||
'revision',
|
||||
'subfolder',
|
||||
'use_auth_token', #
|
||||
]
|
||||
CONFIG_FILE_NAME = 'config.json'
|
||||
# the below is similar to peft.utils.other.CONFIG_NAME
|
||||
|
||||
@@ -12,9 +12,7 @@ from .._helpers import patch_correct_tag, save_model
|
||||
from ..transformers._helpers import get_tokenizer, process_config
|
||||
|
||||
if not is_ctranslate_available():
|
||||
raise RuntimeError(
|
||||
"'ctranslate2' is required to use with backend 'ctranslate'. Install it with 'pip install \"openllm[ctranslate]\"'"
|
||||
)
|
||||
raise RuntimeError("'ctranslate2' is required to use with backend 'ctranslate'. Install it with 'pip install \"openllm[ctranslate]\"'")
|
||||
|
||||
import ctranslate2
|
||||
from ctranslate2.converters.transformers import TransformersConverter
|
||||
@@ -44,17 +42,11 @@ def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
patch_correct_tag(llm, config)
|
||||
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
with save_model(
|
||||
llm, config, False, trust_remote_code, 'ctranslate', [importlib.import_module(tokenizer.__module__)]
|
||||
) as save_metadata:
|
||||
with save_model(llm, config, False, trust_remote_code, 'ctranslate', [importlib.import_module(tokenizer.__module__)]) as save_metadata:
|
||||
bentomodel, _ = save_metadata
|
||||
if llm._local:
|
||||
shutil.copytree(
|
||||
llm.model_id,
|
||||
bentomodel.path,
|
||||
symlinks=False,
|
||||
ignore=shutil.ignore_patterns('.git', 'venv', '__pycache__', '.venv'),
|
||||
dirs_exist_ok=True,
|
||||
llm.model_id, bentomodel.path, symlinks=False, ignore=shutil.ignore_patterns('.git', 'venv', '__pycache__', '.venv'), dirs_exist_ok=True
|
||||
)
|
||||
else:
|
||||
TransformersConverter(
|
||||
@@ -74,9 +66,7 @@ def get(llm):
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__:
|
||||
raise OpenLLMException(
|
||||
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
|
||||
)
|
||||
raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path_of('/hf/'), trust_remote_code=llm.trust_remote_code),
|
||||
|
||||
@@ -15,15 +15,18 @@ logger = logging.getLogger(__name__)
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
|
||||
decls = (*_base_decls, *decls)
|
||||
attrs = {**_base_attrs, **attrs}
|
||||
if llm._local: logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
|
||||
if llm._local:
|
||||
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
patch_correct_tag(llm, config)
|
||||
safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
if llm.quantise != 'gptq': attrs['use_safetensors'] = safe_serialisation
|
||||
if llm.quantise != 'gptq':
|
||||
attrs['use_safetensors'] = safe_serialisation
|
||||
|
||||
model = None
|
||||
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
@@ -36,31 +39,30 @@ def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
if llm.quantise == 'gptq':
|
||||
from optimum.gptq.constants import GPTQ_CONFIG
|
||||
|
||||
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local: # possible local path
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
local_files_only=True,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
llm.model_id, *decls, local_files_only=True, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
|
||||
)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
|
||||
del model
|
||||
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(
|
||||
llm.model_id, local_dir=bentomodel.path, #
|
||||
local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm), #
|
||||
llm.model_id,
|
||||
local_dir=bentomodel.path, #
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm), #
|
||||
)
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm):
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
@@ -79,15 +81,18 @@ def get(llm):
|
||||
|
||||
def check_unintialised_params(model):
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
if len(unintialized) > 0:
|
||||
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
|
||||
def load_model(llm, *decls, **attrs):
|
||||
if llm.quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
if llm.quantise in {'awq', 'squeezellm'}:
|
||||
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
if llm.__llm_backend__ == 'triton': return openllm.models.load_model(llm, config, **attrs)
|
||||
if llm.__llm_backend__ == 'triton':
|
||||
return openllm.models.load_model(llm, config, **attrs)
|
||||
|
||||
auto_class = infer_autoclass_from_llm(llm, config)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
@@ -111,33 +116,30 @@ def load_model(llm, *decls, **attrs):
|
||||
|
||||
try:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
|
||||
llm.bentomodel.path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
use_flash_attention_2=is_flash_attn_2_available(),
|
||||
**attrs,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs)
|
||||
else:
|
||||
try:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
use_flash_attention_2=is_flash_attn_2_available(),
|
||||
**attrs,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
use_flash_attention_2=is_flash_attn_2_available(),
|
||||
**attrs,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
check_unintialised_params(model)
|
||||
return model
|
||||
|
||||
@@ -3,26 +3,36 @@ import transformers
|
||||
from openllm.serialisation.constants import HUB_ATTRS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tokenizer(model_id_or_path, trust_remote_code, **attrs):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code, **attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def process_config(model_id, trust_remote_code, **attrs):
|
||||
config = attrs.pop('config', None)
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto':
|
||||
copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs
|
||||
)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
|
||||
def infer_autoclass_from_llm(llm, config, /):
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if llm.trust_remote_code:
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
raise ValueError(
|
||||
f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
|
||||
)
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map:
|
||||
|
||||
@@ -11,12 +11,14 @@ if t.TYPE_CHECKING:
|
||||
__global_inst__ = None
|
||||
__cached_id__: dict[str, HfModelInfo] = dict()
|
||||
|
||||
|
||||
def Client() -> HfApi:
|
||||
global __global_inst__ # noqa: PLW0603
|
||||
global __global_inst__
|
||||
if __global_inst__ is None:
|
||||
__global_inst__ = HfApi()
|
||||
return __global_inst__
|
||||
|
||||
|
||||
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
if model_id in __cached_id__:
|
||||
return __cached_id__[model_id]
|
||||
@@ -27,14 +29,17 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
traceback.print_exc()
|
||||
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
|
||||
|
||||
|
||||
def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
|
||||
if validate_is_path(model_id):
|
||||
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False)
|
||||
return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
|
||||
has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
|
||||
has_pt_weights = functools.partial(has_weights, extensions='pt')
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = '*.safetensors'
|
||||
@@ -42,11 +47,12 @@ class HfIgnore:
|
||||
tf = '*.h5'
|
||||
flax = '*.msgpack'
|
||||
gguf = '*.gguf'
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm
|
||||
if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm
|
||||
base.append(cls.safetensors)
|
||||
elif has_safetensors_weights(llm.model_id):
|
||||
base.extend([cls.pt, '*.pt'])
|
||||
|
||||
@@ -1,16 +1,36 @@
|
||||
import functools, importlib.metadata, openllm_core
|
||||
|
||||
__all__ = ['generate_labels', 'available_devices', 'device_count']
|
||||
|
||||
|
||||
def generate_labels(llm):
|
||||
return {
|
||||
'backend': llm.__llm_backend__, 'framework': 'openllm', 'model_name': llm.config['model_name'], #
|
||||
'architecture': llm.config['architecture'], 'serialisation': llm._serialisation, #
|
||||
'backend': llm.__llm_backend__,
|
||||
'framework': 'openllm',
|
||||
'model_name': llm.config['model_name'], #
|
||||
'architecture': llm.config['architecture'],
|
||||
'serialisation': llm._serialisation, #
|
||||
**{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
|
||||
}
|
||||
def available_devices(): from ._strategies import NvidiaGpuResource; return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
|
||||
def available_devices():
|
||||
from ._strategies import NvidiaGpuResource
|
||||
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def device_count() -> int: return len(available_devices())
|
||||
def device_count() -> int:
|
||||
return len(available_devices())
|
||||
|
||||
|
||||
def __dir__():
|
||||
coreutils = set(dir(openllm_core.utils)) | set([it for it in openllm_core.utils._extras if not it.startswith('_')]); return sorted(__all__) + sorted(list(coreutils))
|
||||
coreutils = set(dir(openllm_core.utils)) | set([it for it in openllm_core.utils._extras if not it.startswith('_')])
|
||||
return sorted(__all__) + sorted(list(coreutils))
|
||||
|
||||
|
||||
def __getattr__(it):
|
||||
if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it)
|
||||
if hasattr(openllm_core.utils, it):
|
||||
return getattr(openllm_core.utils, it)
|
||||
raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
'''OpenLLM CLI.
|
||||
"""OpenLLM CLI.
|
||||
|
||||
For more information see ``openllm -h``.
|
||||
'''
|
||||
"""
|
||||
|
||||
@@ -5,24 +5,12 @@ from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import shell_completion as sc
|
||||
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import (
|
||||
Concatenate,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralSerialisation,
|
||||
ParamSpec,
|
||||
AnyCallable,
|
||||
get_literal_args,
|
||||
)
|
||||
from openllm_core._typing_compat import Concatenate, DictStrAny, LiteralBackend, LiteralSerialisation, ParamSpec, AnyCallable, get_literal_args
|
||||
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
|
||||
|
||||
|
||||
class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'default_id': 'openllm/generic',
|
||||
'model_ids': ['openllm/generic'],
|
||||
'architecture': 'PreTrainedModel',
|
||||
}
|
||||
__config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'}
|
||||
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
@@ -30,6 +18,7 @@ class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 128
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec('P')
|
||||
@@ -38,6 +27,7 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [
|
||||
sc.CompletionItem(str(it.tag), help='Bento')
|
||||
@@ -45,20 +35,13 @@ def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete
|
||||
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
|
||||
]
|
||||
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None,
|
||||
cors: bool,
|
||||
environ: DictStrAny,
|
||||
config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
@@ -72,26 +55,16 @@ def parse_config_options(
|
||||
]
|
||||
if device:
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend(
|
||||
[
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
]
|
||||
)
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)
|
||||
])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
|
||||
)
|
||||
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(
|
||||
['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"']
|
||||
)
|
||||
_bentoml_config_options_opts.extend(
|
||||
[
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
]
|
||||
)
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
])
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG:
|
||||
@@ -119,22 +92,27 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = name
|
||||
return None
|
||||
|
||||
|
||||
def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]:
|
||||
shared = [
|
||||
dtype_option(factory=factory), model_version_option(factory=factory), #
|
||||
backend_option(factory=factory), quantize_option(factory=factory), #
|
||||
dtype_option(factory=factory),
|
||||
model_version_option(factory=factory), #
|
||||
backend_option(factory=factory),
|
||||
quantize_option(factory=factory), #
|
||||
serialisation_option(factory=factory),
|
||||
]
|
||||
if not _eager: return shared
|
||||
if not _eager:
|
||||
return shared
|
||||
return compose(*shared)(fn)
|
||||
|
||||
|
||||
def start_decorator(fn: FC) -> FC:
|
||||
composed = compose(
|
||||
_OpenLLM_GenericInternalConfig.parse,
|
||||
parse_serve_args(),
|
||||
cog.optgroup.group(
|
||||
'LLM Options',
|
||||
help='''The following options are related to running LLM Server as well as optimization options.
|
||||
help="""The following options are related to running LLM Server as well as optimization options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
@@ -142,7 +120,7 @@ def start_decorator(fn: FC) -> FC:
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
''',
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
@@ -163,12 +141,14 @@ def start_decorator(fn: FC) -> FC:
|
||||
|
||||
return composed(fn)
|
||||
|
||||
|
||||
def parse_device_callback(_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices()))
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
@@ -182,15 +162,12 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
group = cog.optgroup.group('Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]')
|
||||
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands['serve']
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [
|
||||
p
|
||||
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
|
||||
if p.name not in _IGNORED_OPTIONS
|
||||
]
|
||||
serve_options = [p for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
@@ -202,14 +179,16 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F
|
||||
param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts'))
|
||||
f = cog.optgroup.option(*param_decls, **attrs)(f)
|
||||
return group(f)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
'''General ``@click`` decorator with some sauce.
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
|
||||
This decorator extends the default ``@click.option`` plus a factory option and factory attr to
|
||||
provide type-safe click.option or click.argument wrapper for all compatible factory.
|
||||
'''
|
||||
"""
|
||||
factory = attrs.pop('factory', click)
|
||||
factory_attr = attrs.pop('attr', 'option')
|
||||
if factory_attr != 'argument':
|
||||
@@ -242,18 +221,14 @@ def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callab
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs,
|
||||
'--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs
|
||||
)(f)
|
||||
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
|
||||
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--dtype',
|
||||
@@ -264,6 +239,7 @@ def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[F
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
@@ -294,16 +270,14 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='Runtime to use for both serialisation/inference engine.',
|
||||
**attrs)(f)
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
@@ -313,7 +287,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help='''Dynamic quantization for running this LLM.
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
@@ -328,23 +302,25 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
'''
|
||||
"""
|
||||
+ (
|
||||
'''
|
||||
> [!NOTE] that this will set the mode for serving within deployment.''' if build else ''
|
||||
"""
|
||||
> [!NOTE] that this will set the mode for serving within deployment."""
|
||||
if build
|
||||
else ''
|
||||
),
|
||||
**attrs)(f)
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
def workers_per_resource_option(
|
||||
f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any
|
||||
) -> t.Callable[[FC], FC]:
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help='''Number of workers per resource assigned.
|
||||
help="""Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -354,7 +330,7 @@ def workers_per_resource_option(
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
'''
|
||||
"""
|
||||
+ (
|
||||
"""\n
|
||||
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
|
||||
@@ -366,6 +342,7 @@ def workers_per_resource_option(
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
@@ -376,7 +353,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help='''Serialisation format for save/load LLM.
|
||||
help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -385,12 +362,14 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
''',
|
||||
""",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
|
||||
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
@@ -402,9 +381,7 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(
|
||||
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.",
|
||||
ctx,
|
||||
param,
|
||||
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param
|
||||
) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
@@ -6,6 +6,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_core._typing_compat import LiteralSerialisation
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import WARNING_ENV_VAR, codegen, first_not_none, get_disable_warnings, is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm_core._configuration import LLMConfig
|
||||
@@ -13,6 +14,7 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _start(
|
||||
model_id: str,
|
||||
timeout: int = 30,
|
||||
@@ -61,27 +63,25 @@ def _start(
|
||||
additional_args: Additional arguments to pass to ``openllm start``.
|
||||
"""
|
||||
from .entrypoint import start_command
|
||||
|
||||
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
args: list[str] = [model_id]
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend(
|
||||
[
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
|
||||
]
|
||||
)
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)])
|
||||
if quantize: args.extend(['--quantize', str(quantize)])
|
||||
if cors: args.append('--cors')
|
||||
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
|
||||
args.extend(['--device', ','.join(device)])
|
||||
if quantize:
|
||||
args.extend(['--quantize', str(quantize)])
|
||||
if cors:
|
||||
args.append('--cors')
|
||||
if adapter_map:
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
|
||||
)
|
||||
)
|
||||
if additional_args: args.extend(additional_args)
|
||||
if __test__: args.append('--return-process')
|
||||
args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if __test__:
|
||||
args.append('--return-process')
|
||||
|
||||
cmd = start_command
|
||||
return cmd.main(args=args, standalone_mode=False)
|
||||
@@ -138,6 +138,7 @@ def _build(
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
args: list[str] = [
|
||||
sys.executable,
|
||||
'-m',
|
||||
@@ -147,23 +148,34 @@ def _build(
|
||||
'--machine',
|
||||
'--quiet',
|
||||
'--serialisation',
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'),
|
||||
]
|
||||
if quantize: args.extend(['--quantize', quantize])
|
||||
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push: args.extend(['--push'])
|
||||
if containerize: args.extend(['--containerize'])
|
||||
if build_ctx: args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite: args.append('--overwrite')
|
||||
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version: args.extend(['--model-version', model_version])
|
||||
if bento_version: args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template: args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if additional_args: args.extend(additional_args)
|
||||
if force_push: args.append('--force-push')
|
||||
if quantize:
|
||||
args.extend(['--quantize', quantize])
|
||||
if containerize and push:
|
||||
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push:
|
||||
args.extend(['--push'])
|
||||
if containerize:
|
||||
args.extend(['--containerize'])
|
||||
if build_ctx:
|
||||
args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features:
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
args.extend(['--model-version', model_version])
|
||||
if bento_version:
|
||||
args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template:
|
||||
args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if force_push:
|
||||
args.append('--force-push')
|
||||
|
||||
current_disable_warning = get_disable_warnings()
|
||||
os.environ[WARNING_ENV_VAR] = str(True)
|
||||
@@ -171,17 +183,24 @@ def _build(
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
|
||||
if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
if e.stderr:
|
||||
raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
|
||||
try:
|
||||
result = orjson.loads(matched.group(1))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") from e
|
||||
raise ValueError(
|
||||
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
) from e
|
||||
return bentoml.get(result['tag'], _bento_store=bento_store)
|
||||
|
||||
|
||||
def _import_model(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
@@ -218,15 +237,32 @@ def _import_model(
|
||||
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from .entrypoint import import_command
|
||||
|
||||
args = [model_id, '--quiet']
|
||||
if backend is not None: args.extend(['--backend', backend])
|
||||
if model_version is not None: args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None: args.extend(['--quantize', quantize])
|
||||
if serialisation is not None: args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None: args.extend(additional_args)
|
||||
if backend is not None:
|
||||
args.extend(['--backend', backend])
|
||||
if model_version is not None:
|
||||
args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None:
|
||||
args.extend(['--quantize', quantize])
|
||||
if serialisation is not None:
|
||||
args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None:
|
||||
args.extend(additional_args)
|
||||
return import_command.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
'''List all available models within the local store.'''
|
||||
from .entrypoint import models_command; return models_command.main(args=['--quiet'], standalone_mode=False)
|
||||
start, build, import_model, list_models = codegen.gen_sdk(_start), codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models)
|
||||
"""List all available models within the local store."""
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--quiet'], standalone_mode=False)
|
||||
|
||||
|
||||
start, build, import_model, list_models = (
|
||||
codegen.gen_sdk(_start),
|
||||
codegen.gen_sdk(_build),
|
||||
codegen.gen_sdk(_import_model),
|
||||
codegen.gen_sdk(_list_models),
|
||||
)
|
||||
__all__ = ['start', 'build', 'import_model', 'list_models']
|
||||
|
||||
@@ -43,15 +43,7 @@ from openllm_core.utils import (
|
||||
)
|
||||
|
||||
from . import termui
|
||||
from ._factory import (
|
||||
FC,
|
||||
_AnyCallable,
|
||||
machine_option,
|
||||
model_name_argument,
|
||||
parse_config_options,
|
||||
start_decorator,
|
||||
optimization_decorator,
|
||||
)
|
||||
from ._factory import FC, _AnyCallable, machine_option, model_name_argument, parse_config_options, start_decorator, optimization_decorator
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
@@ -65,14 +57,14 @@ else:
|
||||
|
||||
P = ParamSpec('P')
|
||||
logger = logging.getLogger('openllm')
|
||||
OPENLLM_FIGLET = '''\
|
||||
OPENLLM_FIGLET = """\
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
|
||||
'''
|
||||
"""
|
||||
|
||||
ServeCommand = t.Literal['serve', 'serve-grpc']
|
||||
|
||||
@@ -103,20 +95,12 @@ def backend_warning(backend: LiteralBackend, build: bool = False) -> None:
|
||||
'vLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.'
|
||||
)
|
||||
if build:
|
||||
logger.info(
|
||||
"Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally."
|
||||
)
|
||||
logger.info("Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally.")
|
||||
|
||||
|
||||
class Extensions(click.MultiCommand):
|
||||
def list_commands(self, ctx: click.Context) -> list[str]:
|
||||
return sorted(
|
||||
[
|
||||
filename[:-3]
|
||||
for filename in os.listdir(_EXT_FOLDER)
|
||||
if filename.endswith('.py') and not filename.startswith('__')
|
||||
]
|
||||
)
|
||||
return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')])
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
try:
|
||||
@@ -133,41 +117,19 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]:
|
||||
# The following logics is similar to one of BentoMLCommandGroup
|
||||
@cog.optgroup.group(name='Global options', help='Shared globals options for all OpenLLM CLI.') # type: ignore[misc]
|
||||
@cog.optgroup.option('-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True)
|
||||
@cog.optgroup.option(
|
||||
'-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True
|
||||
'--debug', '--verbose', 'debug', envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help='Print out debug logs.', show_envvar=True
|
||||
)
|
||||
@cog.optgroup.option(
|
||||
'--debug',
|
||||
'--verbose',
|
||||
'debug',
|
||||
envvar=DEBUG_ENV_VAR,
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help='Print out debug logs.',
|
||||
show_envvar=True,
|
||||
'--do-not-track', is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help='Do not send usage info', show_envvar=True
|
||||
)
|
||||
@cog.optgroup.option(
|
||||
'--do-not-track',
|
||||
is_flag=True,
|
||||
default=False,
|
||||
envvar=analytics.OPENLLM_DO_NOT_TRACK,
|
||||
help='Do not send usage info',
|
||||
show_envvar=True,
|
||||
)
|
||||
@cog.optgroup.option(
|
||||
'--context',
|
||||
'cloud_context',
|
||||
envvar='BENTOCLOUD_CONTEXT',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='BentoCloud context name.',
|
||||
show_envvar=True,
|
||||
'--context', 'cloud_context', envvar='BENTOCLOUD_CONTEXT', type=click.STRING, default=None, help='BentoCloud context name.', show_envvar=True
|
||||
)
|
||||
@click.pass_context
|
||||
@functools.wraps(f)
|
||||
def wrapper(
|
||||
ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs
|
||||
) -> t.Any:
|
||||
def wrapper(ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs) -> t.Any:
|
||||
ctx.obj = GlobalOptions(cloud_context=cloud_context)
|
||||
if quiet:
|
||||
set_quiet_mode(True)
|
||||
@@ -181,9 +143,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def usage_tracking(
|
||||
func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any
|
||||
) -> t.Callable[Concatenate[bool, P], t.Any]:
|
||||
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]:
|
||||
command_name = attrs.get('name', func.__name__)
|
||||
|
||||
@functools.wraps(func)
|
||||
@@ -242,9 +202,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
_memo = getattr(wrapped, '__click_params__', None)
|
||||
if _memo is None:
|
||||
raise ValueError('Click command not register correctly.')
|
||||
_object_setattr(
|
||||
wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS :] + _memo[: -self.NUMBER_OF_COMMON_PARAMS]
|
||||
)
|
||||
_object_setattr(wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS :] + _memo[: -self.NUMBER_OF_COMMON_PARAMS])
|
||||
# NOTE: we need to call super of super to avoid conflict with BentoMLCommandGroup command setup
|
||||
cmd = super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped)
|
||||
# NOTE: add aliases to a given commands if it is specified.
|
||||
@@ -258,7 +216,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return decorator
|
||||
|
||||
def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
'''Additional format methods that include extensions as well as the default cli command.'''
|
||||
"""Additional format methods that include extensions as well as the default cli command."""
|
||||
from gettext import gettext as _
|
||||
|
||||
commands: list[tuple[str, click.Command]] = []
|
||||
@@ -305,7 +263,7 @@ _PACKAGE_NAME = 'openllm'
|
||||
message=f'{_PACKAGE_NAME}, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}',
|
||||
)
|
||||
def cli() -> None:
|
||||
'''\b
|
||||
"""\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
@@ -316,15 +274,10 @@ def cli() -> None:
|
||||
\b
|
||||
An open platform for operating large language models in production.
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@cli.command(
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
name='start',
|
||||
aliases=['start-http'],
|
||||
short_help='Start a LLMServer for any supported LLM.',
|
||||
)
|
||||
@cli.command(context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'], short_help='Start a LLMServer for any supported LLM.')
|
||||
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
|
||||
@click.option(
|
||||
'--model-id',
|
||||
@@ -366,24 +319,30 @@ def start_command(
|
||||
dtype: LiteralDtype,
|
||||
deprecated_model_id: str | None,
|
||||
max_model_len: int | None,
|
||||
gpu_memory_utilization:float,
|
||||
gpu_memory_utilization: float,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
'''Start any LLM as a REST server.
|
||||
"""Start any LLM as a REST server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm <start|start-http> <model_id> --<options> ...
|
||||
```
|
||||
'''
|
||||
if backend == 'pt': logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.')
|
||||
"""
|
||||
if backend == 'pt':
|
||||
logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.')
|
||||
if model_id in openllm.CONFIG_MAPPING:
|
||||
_model_name = model_id
|
||||
if deprecated_model_id is not None:
|
||||
model_id = deprecated_model_id
|
||||
else:
|
||||
model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
|
||||
logger.warning("Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.", _model_name, '' if deprecated_model_id is None else f' --model-id {deprecated_model_id}', model_id)
|
||||
logger.warning(
|
||||
"Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.",
|
||||
_model_name,
|
||||
'' if deprecated_model_id is None else f' --model-id {deprecated_model_id}',
|
||||
model_id,
|
||||
)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
|
||||
@@ -393,11 +352,7 @@ def start_command(
|
||||
|
||||
if serialisation == 'safetensors' and quantize is not None:
|
||||
logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize)
|
||||
logger.warning(
|
||||
"Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.",
|
||||
model_id,
|
||||
serialisation,
|
||||
)
|
||||
logger.warning("Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.", model_id, serialisation)
|
||||
logger.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
|
||||
|
||||
import torch
|
||||
@@ -425,7 +380,7 @@ def start_command(
|
||||
config, server_attrs = llm.config.model_validate_click(**attrs)
|
||||
server_timeout = first_not_none(server_timeout, default=config['timeout'])
|
||||
server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout})
|
||||
development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
server_attrs.setdefault('production', not development)
|
||||
|
||||
start_env = process_environ(
|
||||
@@ -454,26 +409,27 @@ def start_command(
|
||||
# NOTE: Return the configuration for telemetry purposes.
|
||||
return config
|
||||
|
||||
|
||||
def process_environ(config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True):
|
||||
environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {})
|
||||
environ.update(
|
||||
{
|
||||
'OPENLLM_MODEL_ID': model_id,
|
||||
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
|
||||
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
|
||||
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
|
||||
'BACKEND': llm.__llm_backend__,
|
||||
'DTYPE': str(llm._torch_dtype).split('.')[-1],
|
||||
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
|
||||
'MAX_MODEL_LEN': orjson.dumps(llm._max_model_len).decode(),
|
||||
'GPU_MEMORY_UTILIZATION': orjson.dumps(llm._gpu_memory_utilization).decode(),
|
||||
}
|
||||
)
|
||||
if llm.quantise: environ['QUANTIZE'] = str(llm.quantise)
|
||||
environ.update({
|
||||
'OPENLLM_MODEL_ID': model_id,
|
||||
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
|
||||
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
|
||||
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
|
||||
'BACKEND': llm.__llm_backend__,
|
||||
'DTYPE': str(llm._torch_dtype).split('.')[-1],
|
||||
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
|
||||
'MAX_MODEL_LEN': orjson.dumps(llm._max_model_len).decode(),
|
||||
'GPU_MEMORY_UTILIZATION': orjson.dumps(llm._gpu_memory_utilization).decode(),
|
||||
})
|
||||
if llm.quantise:
|
||||
environ['QUANTIZE'] = str(llm.quantise)
|
||||
return environ
|
||||
|
||||
|
||||
def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]:
|
||||
if isinstance(wpr, str):
|
||||
if wpr == 'round_robin':
|
||||
@@ -491,6 +447,7 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
|
||||
wpr = float(wpr)
|
||||
return wpr
|
||||
|
||||
|
||||
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
|
||||
cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
|
||||
if llm.quantise:
|
||||
@@ -498,12 +455,9 @@ def build_bento_instruction(llm, model_id, serialisation, adapter_map):
|
||||
if llm.__llm_backend__ in {'pt', 'vllm'}:
|
||||
cmd_name += f' --serialization {serialisation}'
|
||||
if adapter_map is not None:
|
||||
cmd_name += ' ' + ' '.join(
|
||||
[
|
||||
f'--adapter-id {s}'
|
||||
for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
|
||||
]
|
||||
)
|
||||
cmd_name += ' ' + ' '.join([
|
||||
f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
|
||||
])
|
||||
if not openllm.utils.get_quiet_mode():
|
||||
termui.info(f"🚀Tip: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
|
||||
|
||||
@@ -537,9 +491,8 @@ def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int
|
||||
if return_process:
|
||||
return process
|
||||
stop_event = threading.Event()
|
||||
# yapf: disable
|
||||
stdout, stderr = threading.Thread(target=handle, args=(process.stdout, stop_event)), threading.Thread(target=handle, args=(process.stderr, stop_event))
|
||||
stdout.start(); stderr.start()
|
||||
stdout.start(); stderr.start() # noqa: E702
|
||||
|
||||
try:
|
||||
process.wait()
|
||||
@@ -554,10 +507,9 @@ def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int
|
||||
raise
|
||||
finally:
|
||||
stop_event.set()
|
||||
stdout.join(); stderr.join()
|
||||
stdout.join(); stderr.join() # noqa: E702
|
||||
if process.poll() is not None: process.kill()
|
||||
stdout.join(); stderr.join()
|
||||
# yapf: disable
|
||||
stdout.join(); stderr.join() # noqa: E702
|
||||
|
||||
return process.returncode
|
||||
|
||||
@@ -645,10 +597,7 @@ def import_command(
|
||||
backend=backend,
|
||||
dtype=dtype,
|
||||
serialisation=t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy')
|
||||
),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
@@ -707,21 +656,14 @@ class BuildBentoOutput(t.TypedDict):
|
||||
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
|
||||
help='Deprecated. Use positional argument instead.',
|
||||
)
|
||||
@click.option(
|
||||
'--bento-version',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Optional bento version for this BentoLLM. Default is the the model revision.',
|
||||
)
|
||||
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
|
||||
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
|
||||
@click.option(
|
||||
'--enable-features',
|
||||
multiple=True,
|
||||
nargs=1,
|
||||
metavar='FEATURE[,FEATURE]',
|
||||
help='Enable additional features for building this LLM Bento. Available: {}'.format(
|
||||
', '.join(OPTIONAL_DEPENDENCIES)
|
||||
),
|
||||
help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES)),
|
||||
)
|
||||
@optimization_decorator
|
||||
@click.option(
|
||||
@@ -732,12 +674,7 @@ class BuildBentoOutput(t.TypedDict):
|
||||
help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.",
|
||||
)
|
||||
@click.option('--build-ctx', help='Build context. This is required if --adapter-id uses relative path', default=None)
|
||||
@click.option(
|
||||
'--dockerfile-template',
|
||||
default=None,
|
||||
type=click.File(),
|
||||
help='Optional custom dockerfile template to be used with this BentoLLM.',
|
||||
)
|
||||
@click.option('--dockerfile-template', default=None, type=click.File(), help='Optional custom dockerfile template to be used with this BentoLLM.')
|
||||
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options') # type: ignore[misc]
|
||||
@cog.optgroup.option(
|
||||
'--containerize',
|
||||
@@ -788,13 +725,13 @@ def build_command(
|
||||
build_ctx: str | None,
|
||||
dockerfile_template: t.TextIO | None,
|
||||
max_model_len: int | None,
|
||||
gpu_memory_utilization:float,
|
||||
gpu_memory_utilization: float,
|
||||
containerize: bool,
|
||||
push: bool,
|
||||
force_push: bool,
|
||||
**_: t.Any,
|
||||
) -> BuildBentoOutput:
|
||||
'''Package a given models into a BentoLLM.
|
||||
"""Package a given models into a BentoLLM.
|
||||
|
||||
\b
|
||||
```bash
|
||||
@@ -810,7 +747,7 @@ def build_command(
|
||||
> [!IMPORTANT]
|
||||
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
|
||||
> target also use the same Python version and architecture as build machine.
|
||||
'''
|
||||
"""
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
if model_id in openllm.CONFIG_MAPPING:
|
||||
@@ -840,9 +777,7 @@ def build_command(
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
serialisation=first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
serialisation=first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'),
|
||||
_eager=False,
|
||||
)
|
||||
if llm.__llm_backend__ not in llm.config['backend']:
|
||||
@@ -854,9 +789,7 @@ def build_command(
|
||||
model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
|
||||
llm._tag = model.tag
|
||||
|
||||
os.environ.update(
|
||||
**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm)
|
||||
)
|
||||
os.environ.update(**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm))
|
||||
|
||||
try:
|
||||
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
|
||||
@@ -923,11 +856,7 @@ def build_command(
|
||||
|
||||
def get_current_bentocloud_context() -> str | None:
|
||||
try:
|
||||
context = (
|
||||
cloud_config.get_context(ctx.obj.cloud_context)
|
||||
if ctx.obj.cloud_context
|
||||
else cloud_config.get_current_context()
|
||||
)
|
||||
context = cloud_config.get_context(ctx.obj.cloud_context) if ctx.obj.cloud_context else cloud_config.get_current_context()
|
||||
return context.name
|
||||
except Exception:
|
||||
return None
|
||||
@@ -951,9 +880,7 @@ def build_command(
|
||||
tag=str(bento_tag),
|
||||
backend=llm.__llm_backend__,
|
||||
instructions=[
|
||||
DeploymentInstruction.from_content(
|
||||
type='bentocloud', instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}", cmd=push_cmd
|
||||
),
|
||||
DeploymentInstruction.from_content(type='bentocloud', instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}", cmd=push_cmd),
|
||||
DeploymentInstruction.from_content(
|
||||
type='container',
|
||||
instr="🐳 Container BentoLLM with 'bentoml containerize':\n $ {cmd}",
|
||||
@@ -979,9 +906,7 @@ def build_command(
|
||||
termui.echo(f" * {instruction['content']}\n", nl=False)
|
||||
|
||||
if push:
|
||||
BentoMLContainer.bentocloud_client.get().push_bento(
|
||||
bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push
|
||||
)
|
||||
BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push)
|
||||
elif containerize:
|
||||
container_backend = t.cast('DefaultBuilder', os.environ.get('BENTOML_CONTAINERIZE_BACKEND', 'docker'))
|
||||
try:
|
||||
@@ -1009,20 +934,19 @@ class ModelItem(t.TypedDict):
|
||||
@cli.command()
|
||||
@click.option('--show-available', is_flag=True, default=True, hidden=True)
|
||||
def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
|
||||
'''List all supported models.
|
||||
"""List all supported models.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm models
|
||||
```
|
||||
'''
|
||||
"""
|
||||
result: dict[t.LiteralString, ModelItem] = {
|
||||
m: ModelItem(
|
||||
architecture=config.__openllm_architecture__,
|
||||
example_id=random.choice(config.__openllm_model_ids__),
|
||||
supported_backends=config.__openllm_backend__,
|
||||
installation='pip install '
|
||||
+ (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'),
|
||||
installation='pip install ' + (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'),
|
||||
items=[
|
||||
str(md.tag)
|
||||
for md in bentoml.models.list()
|
||||
@@ -1041,13 +965,7 @@ def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
|
||||
@cli.command()
|
||||
@model_name_argument(required=False)
|
||||
@click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model')
|
||||
@click.option(
|
||||
'--include-bentos/--no-include-bentos',
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
default=True,
|
||||
help='Whether to also include pruning bentos.',
|
||||
)
|
||||
@click.option('--include-bentos/--no-include-bentos', is_flag=True, hidden=True, default=True, help='Whether to also include pruning bentos.')
|
||||
@inject
|
||||
@click.pass_context
|
||||
def prune_command(
|
||||
@@ -1058,38 +976,30 @@ def prune_command(
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
**_: t.Any,
|
||||
) -> None:
|
||||
'''Remove all saved models, and bentos built with OpenLLM locally.
|
||||
"""Remove all saved models, and bentos built with OpenLLM locally.
|
||||
|
||||
\b
|
||||
If a model type is passed, then only prune models for that given model type.
|
||||
'''
|
||||
"""
|
||||
available: list[tuple[bentoml.Model | bentoml.Bento, ModelStore | BentoStore]] = [
|
||||
(m, model_store)
|
||||
for m in bentoml.models.list()
|
||||
if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm'
|
||||
(m, model_store) for m in bentoml.models.list() if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm'
|
||||
]
|
||||
if model_name is not None:
|
||||
available = [
|
||||
(m, store)
|
||||
for m, store in available
|
||||
if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)
|
||||
(m, store) for m, store in available if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)
|
||||
] + [
|
||||
(b, bento_store)
|
||||
for b in bentoml.bentos.list()
|
||||
if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)
|
||||
]
|
||||
else:
|
||||
available += [
|
||||
(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels
|
||||
]
|
||||
available += [(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels]
|
||||
|
||||
for store_item, store in available:
|
||||
if yes:
|
||||
delete_confirmed = True
|
||||
else:
|
||||
delete_confirmed = click.confirm(
|
||||
f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?"
|
||||
)
|
||||
delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?")
|
||||
if delete_confirmed:
|
||||
store.delete(store_item.tag)
|
||||
termui.warning(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.")
|
||||
@@ -1136,17 +1046,8 @@ def shared_client_options(f: _AnyCallable | None = None) -> t.Callable[[FC], FC]
|
||||
|
||||
@cli.command()
|
||||
@shared_client_options
|
||||
@click.option(
|
||||
'--server-type',
|
||||
type=click.Choice(['grpc', 'http']),
|
||||
help='Server type',
|
||||
default='http',
|
||||
show_default=True,
|
||||
hidden=True,
|
||||
)
|
||||
@click.option(
|
||||
'--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.'
|
||||
)
|
||||
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True, hidden=True)
|
||||
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option(
|
||||
'--sampling-params',
|
||||
@@ -1168,14 +1069,15 @@ def query_command(
|
||||
_memoized: DictStrAny,
|
||||
**_: t.Any,
|
||||
) -> None:
|
||||
'''Query a LLM interactively, from a terminal.
|
||||
"""Query a LLM interactively, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?"
|
||||
```
|
||||
'''
|
||||
if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.")
|
||||
"""
|
||||
if server_type == 'grpc':
|
||||
raise click.ClickException("'grpc' is currently disabled.")
|
||||
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
|
||||
# TODO: grpc support
|
||||
client = openllm.HTTPClient(address=endpoint, timeout=timeout)
|
||||
@@ -1194,7 +1096,7 @@ def query_command(
|
||||
|
||||
@cli.group(cls=Extensions, hidden=True, name='extension')
|
||||
def extension_command() -> None:
|
||||
'''Extension for OpenLLM CLI.'''
|
||||
"""Extension for OpenLLM CLI."""
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -21,10 +21,8 @@ if t.TYPE_CHECKING:
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(
|
||||
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> str | None:
|
||||
'''Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path).'''
|
||||
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
|
||||
@@ -17,9 +17,7 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command(
|
||||
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
|
||||
)
|
||||
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
|
||||
@@ -22,9 +22,7 @@ class PromptFormatter(string.Formatter):
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
@@ -58,9 +56,7 @@ class PromptTemplate:
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@@ -128,21 +124,15 @@ def cli(
|
||||
if prompt_template_file and chat_template_file:
|
||||
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
|
||||
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
|
||||
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
|
||||
)
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys())
|
||||
if model_id in acceptable:
|
||||
logger.warning(
|
||||
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
|
||||
)
|
||||
logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n')
|
||||
config = openllm.AutoConfig.for_model(model_id)
|
||||
template = prompt_template_file.read() if prompt_template_file is not None else config.template
|
||||
system_message = system_message or config.system_message
|
||||
|
||||
try:
|
||||
formatted = (
|
||||
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
)
|
||||
formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
ctx.fail(str(err))
|
||||
@@ -159,21 +149,15 @@ def cli(
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
system_message = (
|
||||
openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
)
|
||||
openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture])
|
||||
.model_construct_env()
|
||||
.system_message
|
||||
)
|
||||
break
|
||||
else:
|
||||
ctx.fail(
|
||||
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
|
||||
)
|
||||
ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message')
|
||||
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
|
||||
)
|
||||
formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False)
|
||||
|
||||
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -13,7 +13,7 @@ from openllm_cli import termui
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
'''List available bentos built by OpenLLM.'''
|
||||
"""List available bentos built by OpenLLM."""
|
||||
mapping = {
|
||||
k: [
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ if t.TYPE_CHECKING:
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
'''List available models in lcoal store to be used wit OpenLLM.'''
|
||||
"""List available models in lcoal store to be used wit OpenLLM."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
@@ -33,17 +33,12 @@ def cli(model_name: str | None) -> DictStrAny:
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in v
|
||||
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
|
||||
]
|
||||
k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)]
|
||||
for k, v in ids_in_local_store.items()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {
|
||||
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
|
||||
for k, val in ids_in_local_store.items()
|
||||
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()
|
||||
}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
|
||||
@@ -32,14 +32,7 @@ def load_notebook_metadata() -> DictStrAny:
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option(
|
||||
'--port',
|
||||
envvar='JUPYTER_PORT',
|
||||
show_envvar=True,
|
||||
show_default=True,
|
||||
default=8888,
|
||||
help='Default port for Jupyter server',
|
||||
)
|
||||
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
@@ -60,9 +53,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError(
|
||||
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
|
||||
)
|
||||
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
|
||||
metadata = load_notebook_metadata()
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
@@ -74,9 +65,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug(
|
||||
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
|
||||
)
|
||||
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
|
||||
continue
|
||||
@@ -86,20 +75,18 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output(
|
||||
[
|
||||
sys.executable,
|
||||
'-m',
|
||||
'jupyter',
|
||||
'notebook',
|
||||
'--notebook-dir',
|
||||
output_dir,
|
||||
'--port',
|
||||
str(port),
|
||||
'--no-browser',
|
||||
'--debug',
|
||||
]
|
||||
)
|
||||
subprocess.check_output([
|
||||
sys.executable,
|
||||
'-m',
|
||||
'jupyter',
|
||||
'notebook',
|
||||
'--notebook-dir',
|
||||
output_dir,
|
||||
'--port',
|
||||
str(port),
|
||||
'--no-browser',
|
||||
'--debug',
|
||||
])
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
|
||||
@@ -25,14 +25,7 @@ class Level(enum.IntEnum):
|
||||
|
||||
@property
|
||||
def color(self) -> str | None:
|
||||
return {
|
||||
Level.NOTSET: None,
|
||||
Level.DEBUG: 'cyan',
|
||||
Level.INFO: 'green',
|
||||
Level.WARNING: 'yellow',
|
||||
Level.ERROR: 'red',
|
||||
Level.CRITICAL: 'red',
|
||||
}[self]
|
||||
return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self]
|
||||
|
||||
@classmethod
|
||||
def from_logging_level(cls, level: int) -> Level:
|
||||
@@ -82,9 +75,5 @@ def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json:
|
||||
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {
|
||||
'help_option_names': ['-h', '--help'],
|
||||
'max_content_width': COLUMNS,
|
||||
'token_normalize_func': inflection.underscore,
|
||||
}
|
||||
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
|
||||
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level']
|
||||
|
||||
Reference in New Issue
Block a user