feat(models): command-r (#1005)

* feat(models): add support for command-r

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

* feat(models): support command-r and remove dead code and extensions

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

* chore: update local.sh script

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Aaron Pham, 2024-06-02 10:16:08 -04:00, committed by GitHub
parent 9649073713, commit bf28f977bc
28 changed files with 628 additions and 923 deletions


@@ -105,7 +105,7 @@ def optimization_decorator(fn: t.Callable[..., t.Any]):
   '--quantise',
   '--quantize',
   'quantise',
-  type=str,
+  type=click.Choice(get_literal_args(LiteralQuantise)),
   default=None,
   envvar='QUANTIZE',
   show_envvar=True,
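
The swap from type=str to click.Choice(get_literal_args(LiteralQuantise)) makes the CLI reject unknown quantisation methods at parse time and keeps the option in sync with the LiteralQuantise type. A minimal sketch of the idea, assuming get_literal_args behaves like typing.get_args on a plain Literal alias (the members below are illustrative, not the exact set):

import typing as t
import click

# Illustrative stand-in for openllm_core._typing_compat.LiteralQuantise.
LiteralQuantise = t.Literal['gptq', 'awq', 'squeezellm']

def get_literal_args(literal: t.Any) -> t.Tuple[str, ...]:
  # For a plain Literal alias this is equivalent to typing.get_args.
  return t.get_args(literal)

@click.command()
@click.option(
  '--quantise', '--quantize', 'quantise',
  type=click.Choice(get_literal_args(LiteralQuantise)), default=None,
)
def cli(quantise: t.Optional[str]) -> None:
  # Anything outside the Literal's members now fails before reaching model code.
  click.echo(f'quantise={quantise}')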


@@ -1,6 +1,6 @@
 from __future__ import annotations
-import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t
+import inspect, orjson, logging, dataclasses, bentoml, functools, attr, os, openllm_core, traceback, openllm, typing as t
 from openllm_core.utils import (
   get_debug_mode,
@@ -10,11 +10,13 @@ from openllm_core.utils import (
   dict_filter_none,
   Counter,
 )
-from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
+from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype, get_literal_args
 from openllm_core._schemas import GenerationOutput
 Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
+logger = logging.getLogger(__name__)
 if t.TYPE_CHECKING:
   from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
@@ -30,11 +32,24 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.
 def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
-  if v is not None and v not in {'gptq', 'awq', 'squeezellm'}:
+  if v is not None and v not in get_literal_args(LiteralQuantise):
     raise ValueError(f'Invalid quantization method: {v}')
   return v
+def update_engine_args(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+  env_json_string = os.environ.get('ENGINE_CONFIG', None)
+  config_from_env = {}
+  if env_json_string is not None:
+    try:
+      config_from_env = orjson.loads(env_json_string)
+    except orjson.JSONDecodeError as e:
+      raise RuntimeError("Failed to parse 'ENGINE_CONFIG' as valid JSON string.") from e
+  config_from_env.update(v)
+  return config_from_env
 @attr.define(init=False)
 class LLM:
   model_id: str
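
The new update_engine_args helper layers JSON from the ENGINE_CONFIG environment variable underneath whatever was passed explicitly, so explicit arguments always win over the environment. A standalone, runnable copy of the function from the diff, with a made-up usage example:

import os, orjson, typing as t

def update_engine_args(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
  env_json_string = os.environ.get('ENGINE_CONFIG', None)
  config_from_env: t.Dict[str, t.Any] = {}
  if env_json_string is not None:
    try:
      config_from_env = orjson.loads(env_json_string)
    except orjson.JSONDecodeError as e:
      raise RuntimeError("Failed to parse 'ENGINE_CONFIG' as valid JSON string.") from e
  config_from_env.update(v)  # explicit arguments take precedence over the env JSON
  return config_from_env

os.environ['ENGINE_CONFIG'] = '{"max_model_len": 4096, "gpu_memory_utilization": 0.9}'
print(update_engine_args({'max_model_len': 8192}))
# -> {'max_model_len': 8192, 'gpu_memory_utilization': 0.9}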
@@ -44,7 +59,7 @@ class LLM:
   dtype: Dtype
   quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
   trust_remote_code: bool = attr.field(default=False)
-  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
+  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args, converter=update_engine_args)
   _mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
   _path: str = attr.field(
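
Passing the helper as converter=update_engine_args means attrs runs it whenever engine_args is set, before the check_engine_args validator sees the value, and also on the factory=dict default. A toy sketch of that ordering (the class and converter below are hypothetical, not openllm's):

import attr

def fold_in_env_defaults(v: dict) -> dict:
  # Hypothetical converter: merge a default underneath the user's value.
  return {'tokenizer_mode': 'auto', **v}

def check_is_dict(inst, attribute, value):
  # Validators run after converters, so this sees the merged dict.
  assert isinstance(value, dict)

@attr.define
class Engine:
  args: dict = attr.field(factory=dict, converter=fold_in_env_defaults, validator=check_is_dict)

print(Engine(args={'dtype': 'auto'}).args)  # -> {'tokenizer_mode': 'auto', 'dtype': 'auto'}
print(Engine().args)                        # converter applies to the default too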
@@ -117,18 +132,27 @@ class LLM:
     num_gpus, dev = 1, openllm.utils.device_count()
     if dev >= 2:
       num_gpus = min(dev // 2 * 2, dev)
-    dtype = 'float16' if self.quantise == 'gptq' else self.dtype  # NOTE: quantise GPTQ doesn't support bfloat16 yet.
-    self.engine_args.update({
-      'worker_use_ray': False,
-      'tokenizer_mode': 'auto',
+    overriden_dict = {
       'tensor_parallel_size': num_gpus,
       'model': self._path,
       'tokenizer': self._path,
       'trust_remote_code': self.trust_remote_code,
-      'dtype': dtype,
+      'dtype': self.dtype,
       'quantization': self.quantise,
-    })
+    }
+    if any(k in self.engine_args for k in overriden_dict.keys()):
+      logger.warning(
+        'The following key will be overriden by openllm: %s (got %s set)',
+        list(overriden_dict),
+        [k for k in overriden_dict if k in self.engine_args],
+      )
+    self.engine_args.update(overriden_dict)
+    if 'worker_use_ray' not in self.engine_args:
+      self.engine_args['worker_use_ray'] = False
+    if 'tokenizer_mode' not in self.engine_args:
+      self.engine_args['tokenizer_mode'] = 'auto'
+    if 'disable_log_stats' not in self.engine_args:
+      self.engine_args['disable_log_stats'] = not get_debug_mode()
+    if 'gpu_memory_utilization' not in self.engine_args:
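
Taken together, the rewritten block gives engine arguments a clear precedence: openllm's hard overrides (model path, tensor parallelism, dtype, quantization) always win, with a warning when the user tried to set one, while keys such as worker_use_ray and tokenizer_mode are soft defaults applied only when unset. A condensed sketch of that policy (names are illustrative, not openllm's API):

import logging, typing as t

logger = logging.getLogger(__name__)

def merge_engine_args(
  user: t.Dict[str, t.Any],
  overrides: t.Dict[str, t.Any],
  soft_defaults: t.Dict[str, t.Any],
) -> t.Dict[str, t.Any]:
  clobbered = [k for k in overrides if k in user]
  if clobbered:
    logger.warning('The following keys will be overridden by openllm: %s', clobbered)
  merged = {**user, **overrides}  # hard overrides always win
  for key, value in soft_defaults.items():
    merged.setdefault(key, value)  # soft defaults only fill gaps the user left
  return merged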