diff --git a/ADDING_NEW_MODEL.md b/ADDING_NEW_MODEL.md
index c31a9ddb..5084a011 100644
--- a/ADDING_NEW_MODEL.md
+++ b/ADDING_NEW_MODEL.md
@@ -15,6 +15,7 @@ Here's your roadmap:
   `$GIT_ROOT/openllm-core/src/openllm_core/config/configuration_{model_name}.py`
 - [ ] Update `$GIT_ROOT/openllm-core/src/openllm_core/config/__init__.py` to import the new model
 - [ ] Add your new model entry in `$GIT_ROOT/openllm-core/src/openllm_core/config/configuration_auto.py` with a tuple of the `model_name` alongside with the `ModelConfig`
+- [ ] Run `./tools/update-config-stubs.py`
 
 > [!NOTE]
 >
diff --git a/openllm-core/src/openllm_core/config/__init__.py b/openllm-core/src/openllm_core/config/__init__.py
index 0f91ae2f..0af667b0 100644
--- a/openllm-core/src/openllm_core/config/__init__.py
+++ b/openllm-core/src/openllm_core/config/__init__.py
@@ -47,3 +47,4 @@ from .configuration_starcoder import (
   START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING,
   StarCoderConfig as StarCoderConfig,
 )
+from .configuration_yi import START_YI_COMMAND_DOCSTRING as START_YI_COMMAND_DOCSTRING, YiConfig as YiConfig
diff --git a/openllm-core/src/openllm_core/config/configuration_auto.py b/openllm-core/src/openllm_core/config/configuration_auto.py
index 6c9b460d..66db071f 100644
--- a/openllm-core/src/openllm_core/config/configuration_auto.py
+++ b/openllm-core/src/openllm_core/config/configuration_auto.py
@@ -40,6 +40,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
     ('stablelm', 'StableLMConfig'),
     ('starcoder', 'StarCoderConfig'),
     ('mistral', 'MistralConfig'),
+    ('yi', 'YiConfig'),
     ('baichuan', 'BaichuanConfig'),
   ]
 )
@@ -151,6 +152,9 @@ class AutoConfig:
   def for_model(cls,model_name:t.Literal['mistral'],**attrs:t.Any)->openllm_core.config.MistralConfig:...
   @t.overload
   @classmethod
+  def for_model(cls,model_name:t.Literal['yi'],**attrs:t.Any)->openllm_core.config.YiConfig:...
+  @t.overload
+  @classmethod
   def for_model(cls,model_name:t.Literal['baichuan'],**attrs:t.Any)->openllm_core.config.BaichuanConfig:...
   # update-config-stubs.py: auto stubs stop
   # fmt: on
diff --git a/openllm-core/src/openllm_core/config/configuration_mistral.py b/openllm-core/src/openllm_core/config/configuration_mistral.py
index aff382d7..1f444f97 100644
--- a/openllm-core/src/openllm_core/config/configuration_mistral.py
+++ b/openllm-core/src/openllm_core/config/configuration_mistral.py
@@ -8,7 +8,7 @@ START_MISTRAL_COMMAND_DOCSTRING = """\
 Run a LLMServer for Mistral model.
 
 \b
-> See more information about Mistral at [Mistral's model card](https://huggingface.co/docs/transformers/main/model_doc/mistral
+> See more information about Mistral at [Mistral's model card](https://huggingface.co/docs/transformers/main/model_doc/mistral)
 
 \b
 ## Usage
diff --git a/openllm-core/src/openllm_core/config/configuration_yi.py b/openllm-core/src/openllm_core/config/configuration_yi.py
new file mode 100644
index 00000000..574dcf7d
--- /dev/null
+++ b/openllm-core/src/openllm_core/config/configuration_yi.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+import typing as t
+
+import openllm_core
+from openllm_core.prompts import PromptTemplate
+
+START_YI_COMMAND_DOCSTRING = """\
+Run a LLMServer for Yi model.
+
+\b
+> See more information about Yi at [Yi's GitHub](https://github.com/01-ai/Yi)
+
+\b
+## Usage
+
+By default, this model will use [vLLM](https://github.com/vllm-project/vllm) for inference.
+This model also supports PyTorch.
+Note that this model requires trusting remote code,
+meaning you need to set ``TRUST_REMOTE_CODE=True`` before running the model.
+
+\b
+- To use PyTorch, set the environment variable ``OPENLLM_BACKEND="pt"``
+
+\b
+- To use vLLM, set the environment variable ``OPENLLM_BACKEND="vllm"``
+
+\b
+Yi Runner will use 01-ai/Yi-6B as the default model.
+To use any other saved pretrained Yi model, or a fine-tuned Yi,
+provide ``OPENLLM_MODEL_ID='01-ai/Yi-6B'`` or pass the model_id when running ``openllm start``:
+
+\b
+$ TRUST_REMOTE_CODE=True openllm start 01-ai/Yi-6B
+"""
+
+DEFAULT_SYSTEM_MESSAGE = ''
+DEFAULT_PROMPT_TEMPLATE = '{instruction}'
+
+
+class YiConfig(openllm_core.LLMConfig):
+  """The Yi series models are large language models trained from scratch by developers at 01.AI.
+
+  The first public release contains two bilingual (English/Chinese) base models with parameter sizes of 6B (Yi-6B) and 34B (Yi-34B).
+  Both are trained with a 4K sequence length and can be extended to 32K during inference time. Yi-6B-200K and Yi-34B-200K are base models with a 200K context length.
+
+  See [Yi's GitHub](https://github.com/01-ai/Yi) for more information.
+  """
+
+  __config__ = {
+    'name_type': 'lowercase',
+    'url': 'https://01.ai/',
+    'architecture': 'YiForCausalLM',
+    'trust_remote_code': True,
+    'default_id': '01-ai/Yi-6B',
+    'serialisation': 'safetensors',
+    'model_ids': ['01-ai/Yi-6B', '01-ai/Yi-34B', '01-ai/Yi-6B-200K', '01-ai/Yi-34B-200K'],
+  }
+
+  class GenerationConfig:
+    max_new_tokens: int = 256
+    temperature: float = 0.7
+    repetition_penalty: float = 1.3
+    no_repeat_ngram_size: int = 5
+    top_p: float = 0.9
+    top_k: int = 40
+
+  class SamplingParams:
+    best_of: int = 1
+    presence_penalty: float = 0.5
+
+  @property
+  def default_prompt_template(self) -> str:
+    return DEFAULT_PROMPT_TEMPLATE
+
+  @property
+  def default_system_message(self) -> str:
+    return DEFAULT_SYSTEM_MESSAGE
+
+  def sanitize_parameters(
+    self,
+    prompt: str,
+    top_k: int,
+    top_p: float,
+    temperature: float,
+    max_new_tokens: int,
+    repetition_penalty: float,
+    prompt_template: PromptTemplate | str | None = None,
+    system_message: str | None = None,
+    **attrs: t.Any,
+  ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
+    system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
+    if prompt_template is None:
+      prompt_template = DEFAULT_PROMPT_TEMPLATE
+    # coerce plain string templates (including the default) into a PromptTemplate
+    if isinstance(prompt_template, str):
+      prompt_template = PromptTemplate(template=prompt_template)
+
+    return (
+      prompt_template.with_options(system_message=system_message).format(instruction=prompt),
+      {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k},
+      {},
+    )
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index 7c10da7f..b039c65b 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -36,9 +36,9 @@ from openllm_core.exceptions import MissingDependencyError
 from openllm_core.prompts import PromptTemplate
 from openllm_core.utils import (
   DEBUG,
+  ENV_VARS_TRUE_VALUES,
   ReprMixin,
   apply,
-  check_bool_env,
   codegen,
   converter,
   first_not_none,
@@ -326,7 +326,10 @@ class LLM(t.Generic[M, T], ReprMixin):
 
   @property
   def trust_remote_code(self) -> bool:
-    return first_not_none(check_bool_env('TRUST_REMOTE_CODE', False), default=self.__llm_trust_remote_code__)
+    env = os.getenv('TRUST_REMOTE_CODE')
+    if env is not None:
+      return str(env).upper() in ENV_VARS_TRUE_VALUES
+    return self.__llm_trust_remote_code__
 
   @property
   def runner_name(self) -> str:
diff --git a/openllm-python/src/openllm/bundle/_package.py b/openllm-python/src/openllm/bundle/_package.py
index f689fbcb..e6e5d666 100644
--- a/openllm-python/src/openllm/bundle/_package.py
+++ b/openllm-python/src/openllm/bundle/_package.py
@@ -134,6 +134,7 @@ def construct_docker_options(
     'BENTOML_DEBUG': str(True),
     'BENTOML_QUIET': str(False),
     'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
+    'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
   }
   if adapter_map:
     env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py
index fe5a0783..661f93bb 100644
--- a/openllm-python/src/openllm_cli/entrypoint.py
+++ b/openllm-python/src/openllm_cli/entrypoint.py
@@ -277,7 +277,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
   def list_commands(self, ctx: click.Context) -> list[str]:
     return super().list_commands(ctx) + t.cast('Extensions', extension_command).list_commands(ctx)
 
-  def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:  # type: ignore[override]  # XXX: fix decorator on BentoMLCommandGroup
+  def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:
     """Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters."""
     if 'context_settings' not in kwargs:
       kwargs['context_settings'] = {}
@@ -457,7 +457,6 @@ def start_command(
     quantize=quantize,
     serialisation=serialisation,
     torch_dtype=dtype,
-    trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
   )
 
   backend_warning(llm.__llm_backend__)
@@ -635,6 +634,7 @@ def process_environ(
       'OPENLLM_BACKEND': llm.__llm_backend__,
       'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
       'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
+      'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
     }
   )
   if llm.quantise:
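
As a quick sanity check of the pieces this patch wires together (the `('yi', 'YiConfig')` mapping entry, the `AutoConfig.for_model` overload, and `YiConfig.sanitize_parameters`), a minimal usage sketch might look like the following. This is not part of the patch itself; it assumes the package is installed and that `AutoConfig` is importable from `openllm_core.config` alongside the configs re-exported in `__init__.py`:

    from openllm_core.config import AutoConfig  # assumption: AutoConfig is exported here

    # Resolve the new Yi config through the auto-mapping added above.
    config = AutoConfig.for_model('yi')  # -> openllm_core.config.YiConfig

    # Format a prompt using the defaults defined in configuration_yi.py.
    prompt, generate_kwargs, _ = config.sanitize_parameters(
      'What is the capital of France?',
      top_k=40,
      top_p=0.9,
      temperature=0.7,
      max_new_tokens=256,
      repetition_penalty=1.3,
    )
    # generate_kwargs == {'max_new_tokens': 256, 'temperature': 0.7, 'top_p': 0.9, 'top_k': 40}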