mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-07 22:33:28 -05:00
feat(models): command-r (#1005)
* feat(models): add support for command-r Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * feat(models): support command-r and remove deadcode and extensions Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update local.sh script Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -105,7 +105,7 @@ def optimization_decorator(fn: t.Callable[..., t.Any]):
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantise',
|
||||
type=str,
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='QUANTIZE',
|
||||
show_envvar=True,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t
|
||||
import inspect, orjson, logging, dataclasses, bentoml, functools, attr, os, openllm_core, traceback, openllm, typing as t
|
||||
|
||||
from openllm_core.utils import (
|
||||
get_debug_mode,
|
||||
@@ -10,11 +10,13 @@ from openllm_core.utils import (
|
||||
dict_filter_none,
|
||||
Counter,
|
||||
)
|
||||
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
|
||||
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype, get_literal_args
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
|
||||
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
|
||||
|
||||
@@ -30,11 +32,24 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.
|
||||
|
||||
|
||||
def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
|
||||
if v is not None and v not in {'gptq', 'awq', 'squeezellm'}:
|
||||
if v is not None and v not in get_literal_args(LiteralQuantise):
|
||||
raise ValueError(f'Invalid quantization method: {v}')
|
||||
return v
|
||||
|
||||
|
||||
def update_engine_args(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
|
||||
env_json_string = os.environ.get('ENGINE_CONFIG', None)
|
||||
|
||||
config_from_env = {}
|
||||
if env_json_string is not None:
|
||||
try:
|
||||
config_from_env = orjson.loads(env_json_string)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise RuntimeError("Failed to parse 'ENGINE_CONFIG' as valid JSON string.") from e
|
||||
config_from_env.update(v)
|
||||
return config_from_env
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class LLM:
|
||||
model_id: str
|
||||
@@ -44,7 +59,7 @@ class LLM:
|
||||
dtype: Dtype
|
||||
quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
|
||||
trust_remote_code: bool = attr.field(default=False)
|
||||
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
|
||||
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args, converter=update_engine_args)
|
||||
|
||||
_mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
|
||||
_path: str = attr.field(
|
||||
@@ -117,18 +132,27 @@ class LLM:
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
dtype = 'float16' if self.quantise == 'gptq' else self.dtype # NOTE: quantise GPTQ doesn't support bfloat16 yet.
|
||||
|
||||
self.engine_args.update({
|
||||
'worker_use_ray': False,
|
||||
'tokenizer_mode': 'auto',
|
||||
overriden_dict = {
|
||||
'tensor_parallel_size': num_gpus,
|
||||
'model': self._path,
|
||||
'tokenizer': self._path,
|
||||
'trust_remote_code': self.trust_remote_code,
|
||||
'dtype': dtype,
|
||||
'dtype': self.dtype,
|
||||
'quantization': self.quantise,
|
||||
})
|
||||
}
|
||||
if any(k in self.engine_args for k in overriden_dict.keys()):
|
||||
logger.warning(
|
||||
'The following key will be overriden by openllm: %s (got %s set)',
|
||||
list(overriden_dict),
|
||||
[k for k in overriden_dict if k in self.engine_args],
|
||||
)
|
||||
|
||||
self.engine_args.update(overriden_dict)
|
||||
if 'worker_use_ray' not in self.engine_args:
|
||||
self.engine_args['worker_use_ray'] = False
|
||||
if 'tokenizer_mode' not in self.engine_args:
|
||||
self.engine_args['tokenizer_mode'] = 'auto'
|
||||
if 'disable_log_stats' not in self.engine_args:
|
||||
self.engine_args['disable_log_stats'] = not get_debug_mode()
|
||||
if 'gpu_memory_utilization' not in self.engine_args:
|
||||
|
||||
Reference in New Issue
Block a user