mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-21 15:39:36 -04:00
feat(models): command-r (#1005)
* feat(models): add support for command-r Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * feat(models): support command-r and remove deadcode and extensions Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update local.sh script Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -105,7 +105,7 @@ def optimization_decorator(fn: t.Callable[..., t.Any]):
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantise',
|
||||
type=str,
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='QUANTIZE',
|
||||
show_envvar=True,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t
|
||||
import inspect, orjson, logging, dataclasses, bentoml, functools, attr, os, openllm_core, traceback, openllm, typing as t
|
||||
|
||||
from openllm_core.utils import (
|
||||
get_debug_mode,
|
||||
@@ -10,11 +10,13 @@ from openllm_core.utils import (
|
||||
dict_filter_none,
|
||||
Counter,
|
||||
)
|
||||
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
|
||||
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype, get_literal_args
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
|
||||
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
|
||||
|
||||
@@ -30,11 +32,24 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.
|
||||
|
||||
|
||||
def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
  """attrs validator for the ``quantise`` field.

  Accepts ``None`` (quantisation disabled) or one of the method names
  enumerated by the ``LiteralQuantise`` type alias. The allowed set is derived
  with ``get_literal_args`` so the validator stays in sync with the alias
  instead of repeating a hard-coded set of method names.

  Args:
    _: the attrs instance being validated (unused).
    attr: the attrs attribute descriptor (unused).
    v: the candidate quantisation method, or ``None``.

  Returns:
    *v* unchanged when valid.

  Raises:
    ValueError: if *v* is a string that is not a supported quantisation method.
  """
  if v is not None and v not in get_literal_args(LiteralQuantise):
    raise ValueError(f'Invalid quantization method: {v}')
  return v
|
||||
|
||||
|
||||
def update_engine_args(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
  """Merge engine arguments from the ``ENGINE_CONFIG`` env var with *v*.

  The environment variable, when set, must hold a JSON object; its entries
  form the base configuration and the explicitly provided *v* entries are
  layered on top (so *v* wins on key collisions).

  Args:
    v: explicitly provided engine arguments; takes precedence over env config.

  Returns:
    A new dict combining the env-derived config with *v*.

  Raises:
    RuntimeError: if ``ENGINE_CONFIG`` is set but is not valid JSON.
  """
  merged: t.Dict[str, t.Any] = {}
  raw = os.environ.get('ENGINE_CONFIG')
  if raw is not None:
    try:
      merged = orjson.loads(raw)
    except orjson.JSONDecodeError as err:
      raise RuntimeError("Failed to parse 'ENGINE_CONFIG' as valid JSON string.") from err
  merged.update(v)
  return merged
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class LLM:
|
||||
model_id: str
|
||||
@@ -44,7 +59,7 @@ class LLM:
|
||||
dtype: Dtype
|
||||
quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
|
||||
trust_remote_code: bool = attr.field(default=False)
|
||||
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
|
||||
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args, converter=update_engine_args)
|
||||
|
||||
_mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
|
||||
_path: str = attr.field(
|
||||
@@ -117,18 +132,27 @@ class LLM:
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
dtype = 'float16' if self.quantise == 'gptq' else self.dtype # NOTE: quantise GPTQ doesn't support bfloat16 yet.
|
||||
|
||||
self.engine_args.update({
|
||||
'worker_use_ray': False,
|
||||
'tokenizer_mode': 'auto',
|
||||
overriden_dict = {
|
||||
'tensor_parallel_size': num_gpus,
|
||||
'model': self._path,
|
||||
'tokenizer': self._path,
|
||||
'trust_remote_code': self.trust_remote_code,
|
||||
'dtype': dtype,
|
||||
'dtype': self.dtype,
|
||||
'quantization': self.quantise,
|
||||
})
|
||||
}
|
||||
if any(k in self.engine_args for k in overriden_dict.keys()):
|
||||
logger.warning(
|
||||
'The following key will be overriden by openllm: %s (got %s set)',
|
||||
list(overriden_dict),
|
||||
[k for k in overriden_dict if k in self.engine_args],
|
||||
)
|
||||
|
||||
self.engine_args.update(overriden_dict)
|
||||
if 'worker_use_ray' not in self.engine_args:
|
||||
self.engine_args['worker_use_ray'] = False
|
||||
if 'tokenizer_mode' not in self.engine_args:
|
||||
self.engine_args['tokenizer_mode'] = 'auto'
|
||||
if 'disable_log_stats' not in self.engine_args:
|
||||
self.engine_args['disable_log_stats'] = not get_debug_mode()
|
||||
if 'gpu_memory_utilization' not in self.engine_args:
|
||||
|
||||
@@ -13,7 +13,7 @@ Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
# fmt: off
|
||||
# update-config-stubs.py: import stubs start
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DbrxConfig as DbrxConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, GemmaConfig as GemmaConfig, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, CohereConfig as CohereConfig, DbrxConfig as DbrxConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, GemmaConfig as GemmaConfig, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig
|
||||
from openllm_core._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput, MessageParam as MessageParam
|
||||
from openllm_core.utils import api as api
|
||||
|
||||
@@ -6,7 +6,6 @@ from openllm_core.utils import (
|
||||
DEV_DEBUG_VAR as DEV_DEBUG_VAR,
|
||||
ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES,
|
||||
MYPY as MYPY,
|
||||
OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES,
|
||||
QUIET_ENV_VAR as QUIET_ENV_VAR,
|
||||
SHOW_CODEGEN as SHOW_CODEGEN,
|
||||
LazyLoader as LazyLoader,
|
||||
|
||||
Reference in New Issue
Block a user