feat(models): command-r (#1005)

* feat(models): add support for command-r

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

* feat(models): support command-r and remove dead code and extensions

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

* chore: update local.sh script

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2024-06-02 10:16:08 -04:00
committed by GitHub
parent 9649073713
commit bf28f977bc
28 changed files with 628 additions and 923 deletions

View File

@@ -105,7 +105,7 @@ def optimization_decorator(fn: t.Callable[..., t.Any]):
'--quantise',
'--quantize',
'quantise',
type=str,
type=click.Choice(get_literal_args(LiteralQuantise)),
default=None,
envvar='QUANTIZE',
show_envvar=True,

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t
import inspect, orjson, logging, dataclasses, bentoml, functools, attr, os, openllm_core, traceback, openllm, typing as t
from openllm_core.utils import (
get_debug_mode,
@@ -10,11 +10,13 @@ from openllm_core.utils import (
dict_filter_none,
Counter,
)
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype, get_literal_args
from openllm_core._schemas import GenerationOutput
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
@@ -30,11 +32,24 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.
def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
if v is not None and v not in {'gptq', 'awq', 'squeezellm'}:
if v is not None and v not in get_literal_args(LiteralQuantise):
raise ValueError(f'Invalid quantization method: {v}')
return v
def update_engine_args(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
  """Merge engine arguments from the ``ENGINE_CONFIG`` environment variable with *v*.

  ``ENGINE_CONFIG``, when set, must be a JSON object; keys given explicitly in *v*
  take precedence over keys coming from the environment. A new dict is returned
  and *v* itself is never mutated.

  Raises:
    RuntimeError: if ``ENGINE_CONFIG`` is set but is not valid JSON, or parses
      to something other than a JSON object.
  """
  env_json_string = os.environ.get('ENGINE_CONFIG')
  config_from_env: t.Dict[str, t.Any] = {}
  if env_json_string is not None:
    try:
      config_from_env = orjson.loads(env_json_string)
    except orjson.JSONDecodeError as e:
      raise RuntimeError("Failed to parse 'ENGINE_CONFIG' as valid JSON string.") from e
    # orjson.loads accepts any top-level JSON value (list, string, number);
    # reject non-objects here with a clear error instead of failing later
    # with an AttributeError on .update().
    if not isinstance(config_from_env, dict):
      raise RuntimeError("'ENGINE_CONFIG' must be a JSON object.")
  config_from_env.update(v)
  return config_from_env
@attr.define(init=False)
class LLM:
model_id: str
@@ -44,7 +59,7 @@ class LLM:
dtype: Dtype
quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
trust_remote_code: bool = attr.field(default=False)
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args, converter=update_engine_args)
_mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
_path: str = attr.field(
@@ -117,18 +132,27 @@ class LLM:
num_gpus, dev = 1, openllm.utils.device_count()
if dev >= 2:
num_gpus = min(dev // 2 * 2, dev)
dtype = 'float16' if self.quantise == 'gptq' else self.dtype # NOTE: quantise GPTQ doesn't support bfloat16 yet.
self.engine_args.update({
'worker_use_ray': False,
'tokenizer_mode': 'auto',
overriden_dict = {
'tensor_parallel_size': num_gpus,
'model': self._path,
'tokenizer': self._path,
'trust_remote_code': self.trust_remote_code,
'dtype': dtype,
'dtype': self.dtype,
'quantization': self.quantise,
})
}
if any(k in self.engine_args for k in overriden_dict.keys()):
logger.warning(
'The following key will be overriden by openllm: %s (got %s set)',
list(overriden_dict),
[k for k in overriden_dict if k in self.engine_args],
)
self.engine_args.update(overriden_dict)
if 'worker_use_ray' not in self.engine_args:
self.engine_args['worker_use_ray'] = False
if 'tokenizer_mode' not in self.engine_args:
self.engine_args['tokenizer_mode'] = 'auto'
if 'disable_log_stats' not in self.engine_args:
self.engine_args['disable_log_stats'] = not get_debug_mode()
if 'gpu_memory_utilization' not in self.engine_args:

View File

@@ -13,7 +13,7 @@ Fine-tune, serve, deploy, and monitor any LLMs with ease.
# fmt: off
# update-config-stubs.py: import stubs start
from openllm_client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DbrxConfig as DbrxConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, GemmaConfig as GemmaConfig, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, CohereConfig as CohereConfig, DbrxConfig as DbrxConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, GemmaConfig as GemmaConfig, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig
from openllm_core._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput, MessageParam as MessageParam
from openllm_core.utils import api as api

View File

@@ -6,7 +6,6 @@ from openllm_core.utils import (
DEV_DEBUG_VAR as DEV_DEBUG_VAR,
ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES,
MYPY as MYPY,
OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES,
QUIET_ENV_VAR as QUIET_ENV_VAR,
SHOW_CODEGEN as SHOW_CODEGEN,
LazyLoader as LazyLoader,