mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-16 12:39:16 -04:00
chore: cleanup unused prompt templates (#713)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -31,8 +31,6 @@ class Metadata(_SchemaMixin):
|
||||
model_name: str
|
||||
backend: str
|
||||
configuration: t.Dict[str, t.Any]
|
||||
prompt_template: t.Optional[str]
|
||||
system_message: t.Optional[str]
|
||||
|
||||
|
||||
def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metadata:
|
||||
@@ -49,8 +47,6 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada
|
||||
model_name=data['model_name'],
|
||||
backend=data['backend'],
|
||||
configuration=configuration,
|
||||
prompt_template=data['prompt_template'],
|
||||
system_message=data['system_message'],
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None
|
||||
|
||||
@@ -26,4 +26,3 @@ from .config import (
|
||||
StableLMConfig as StableLMConfig,
|
||||
StarCoderConfig as StarCoderConfig,
|
||||
)
|
||||
from .prompts import PromptTemplate as PromptTemplate
|
||||
|
||||
@@ -1460,13 +1460,13 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
|
||||
|
||||
def inference_options(self, llm: openllm.LLM[M, T], backend: str | None = None) -> tuple[Self, t.Any]:
|
||||
backend = backend if backend is not None else llm.__llm_backend__
|
||||
cls = getattr(self, backend, None)
|
||||
if cls is None:
|
||||
framework = getattr(self, backend, None)
|
||||
if framework is None:
|
||||
raise ValueError(f'Unknown backend {backend}')
|
||||
try:
|
||||
return self, cls.build(self)
|
||||
return self, framework.build(self)
|
||||
except AttributeError:
|
||||
raise RuntimeError(f'Unknown backend {llm.__llm_backend__}') from None
|
||||
raise RuntimeError(f'Unknown backend {backend}') from None
|
||||
|
||||
class vllm:
|
||||
@staticmethod
|
||||
@@ -1503,73 +1503,78 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
|
||||
def build(config: LLMConfig) -> transformers.GenerationConfig:
|
||||
return transformers.GenerationConfig(**converter.unstructure(config.generation_config))
|
||||
|
||||
def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny:
|
||||
warnings.warn(
|
||||
"'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3
|
||||
)
|
||||
_, config = self.inference_options(None, 'hf')
|
||||
return config.to_dict() if return_as_dict else config
|
||||
|
||||
def to_sampling_config(self) -> vllm.SamplingParams:
|
||||
warnings.warn(
|
||||
"'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3
|
||||
)
|
||||
return self.inference_options(None, 'vllm')[-1]
|
||||
@overload
|
||||
def compatible_options(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ...
|
||||
|
||||
@overload
|
||||
def with_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ...
|
||||
def compatible_options(self, request: CohereChatRequest | CohereGenerateRequest) -> dict[str, t.Any]: ...
|
||||
|
||||
@overload
|
||||
def with_request(self, request: CohereChatRequest | CohereGenerateRequest) -> dict[str, t.Any]: ...
|
||||
|
||||
def with_request(self, request: AttrsInstance) -> dict[str, t.Any]:
|
||||
def compatible_options(self, request: AttrsInstance) -> dict[str, t.Any]:
|
||||
if importlib.util.find_spec('openllm') is None:
|
||||
raise MissingDependencyError(
|
||||
"'openllm' is required to use 'with_request'. Make sure to install with 'pip install openllm'."
|
||||
"'openllm' is required to use 'compatible_options'. Make sure to install with 'pip install openllm'."
|
||||
)
|
||||
from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest
|
||||
from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
if isinstance(request, (ChatCompletionRequest, CompletionRequest)):
|
||||
return self._with_openai_request(request)
|
||||
return self.openai.build(self, request)
|
||||
elif isinstance(request, (CohereChatRequest, CohereGenerateRequest)):
|
||||
return self._with_cohere_request(request)
|
||||
return self.cohere.build(self, request)
|
||||
else:
|
||||
raise TypeError(f'Unknown request type {type(request)}')
|
||||
|
||||
def _with_openai_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]:
|
||||
d = dict(
|
||||
temperature=first_not_none(request.temperature, self['temperature']),
|
||||
top_p=first_not_none(request.top_p, self['top_p']),
|
||||
top_k=first_not_none(request.top_k, self['top_k']),
|
||||
best_of=first_not_none(request.best_of, self['best_of']),
|
||||
n=first_not_none(request.n, default=self['n']),
|
||||
stop=first_not_none(request.stop, default=None),
|
||||
max_new_tokens=first_not_none(request.max_tokens, default=self['max_new_tokens']),
|
||||
presence_penalty=first_not_none(request.presence_penalty, default=self['presence_penalty']),
|
||||
frequency_penalty=first_not_none(request.frequency_penalty, default=self['frequency_penalty']),
|
||||
)
|
||||
if hasattr(request, 'logprobs'):
|
||||
d['logprobs'] = first_not_none(request.logprobs, default=self['logprobs'])
|
||||
return d
|
||||
class openai:
|
||||
@staticmethod
|
||||
def build(config: LLMConfig, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]:
|
||||
d = dict(
|
||||
temperature=first_not_none(request.temperature, config['temperature']),
|
||||
top_p=first_not_none(request.top_p, config['top_p']),
|
||||
top_k=first_not_none(request.top_k, config['top_k']),
|
||||
best_of=first_not_none(request.best_of, config['best_of']),
|
||||
n=first_not_none(request.n, default=config['n']),
|
||||
stop=first_not_none(request.stop, default=None),
|
||||
max_new_tokens=first_not_none(request.max_tokens, default=config['max_new_tokens']),
|
||||
presence_penalty=first_not_none(request.presence_penalty, default=config['presence_penalty']),
|
||||
frequency_penalty=first_not_none(request.frequency_penalty, default=config['frequency_penalty']),
|
||||
)
|
||||
if hasattr(request, 'logprobs'):
|
||||
d['logprobs'] = first_not_none(request.logprobs, default=config['logprobs'])
|
||||
return d
|
||||
|
||||
def _with_cohere_request(self, request: CohereGenerateRequest | CohereChatRequest) -> dict[str, t.Any]:
|
||||
d = dict(
|
||||
max_new_tokens=first_not_none(request.max_tokens, default=self['max_new_tokens']),
|
||||
temperature=first_not_none(request.temperature, default=self['temperature']),
|
||||
top_k=first_not_none(request.k, default=self['top_k']),
|
||||
top_p=first_not_none(request.p, default=self['top_p']),
|
||||
)
|
||||
if hasattr(request, 'num_generations'):
|
||||
d['n'] = first_not_none(request.num_generations, default=self['n'])
|
||||
if hasattr(request, 'frequency_penalty'):
|
||||
d['frequency_penalty'] = first_not_none(request.frequency_penalty, default=self['frequency_penalty'])
|
||||
if hasattr(request, 'presence_penalty'):
|
||||
d['presence_penalty'] = first_not_none(request.presence_penalty, default=self['presence_penalty'])
|
||||
return d
|
||||
class cohere:
|
||||
@staticmethod
|
||||
def build(config: LLMConfig, request: CohereGenerateRequest | CohereChatRequest) -> dict[str, t.Any]:
|
||||
d = dict(
|
||||
max_new_tokens=first_not_none(request.max_tokens, default=config['max_new_tokens']),
|
||||
temperature=first_not_none(request.temperature, default=config['temperature']),
|
||||
top_k=first_not_none(request.k, default=config['top_k']),
|
||||
top_p=first_not_none(request.p, default=config['top_p']),
|
||||
)
|
||||
if hasattr(request, 'num_generations'):
|
||||
d['n'] = first_not_none(request.num_generations, default=config['n'])
|
||||
if hasattr(request, 'frequency_penalty'):
|
||||
d['frequency_penalty'] = first_not_none(request.frequency_penalty, default=config['frequency_penalty'])
|
||||
if hasattr(request, 'presence_penalty'):
|
||||
d['presence_penalty'] = first_not_none(request.presence_penalty, default=config['presence_penalty'])
|
||||
return d
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
# Return the default prompt templates for given models.
|
||||
# This is probably only useful for debugging. Users should be responsible
|
||||
# for formatting the prompt themselves.
|
||||
# by default it is just a '{instruction}'
|
||||
return '{system_message}{instruction}'
|
||||
|
||||
@property
|
||||
def system_message(self) -> str:
|
||||
# Return the default system message for given models.
|
||||
# This should really only be used for chat models.
|
||||
return ''
|
||||
|
||||
@classmethod
|
||||
def to_click_options(cls, f: AnyCallable) -> click.Command:
|
||||
def parse(cls, f: AnyCallable) -> click.Command:
|
||||
"""Convert current configuration to click options.
|
||||
|
||||
This can be used as a decorator for click commands.
|
||||
@@ -1614,6 +1619,20 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
|
||||
def peft_task_type(cls) -> str:
|
||||
return PEFT_TASK_TYPE_TARGET_MAPPING[cls.__openllm_model_type__]
|
||||
|
||||
# deprecated
|
||||
def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny:
|
||||
warnings.warn(
|
||||
"'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3
|
||||
)
|
||||
_, config = self.inference_options(None, 'hf')
|
||||
return config.to_dict() if return_as_dict else config
|
||||
|
||||
def to_sampling_config(self) -> vllm.SamplingParams:
|
||||
warnings.warn(
|
||||
"'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3
|
||||
)
|
||||
return self.inference_options(None, 'vllm')[-1]
|
||||
|
||||
|
||||
converter.register_unstructure_hook_factory(
|
||||
lambda cls: lenient_issubclass(cls, LLMConfig),
|
||||
|
||||
@@ -34,8 +34,6 @@ class MetadataOutput(_SchemaMixin):
|
||||
model_name: str
|
||||
backend: str
|
||||
configuration: str
|
||||
prompt_template: t.Optional[str]
|
||||
system_message: t.Optional[str]
|
||||
|
||||
def model_dump(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
@@ -44,8 +42,6 @@ class MetadataOutput(_SchemaMixin):
|
||||
'model_name': self.model_name,
|
||||
'backend': self.backend,
|
||||
'configuration': self.configuration,
|
||||
'prompt_template': orjson.dumps(self.prompt_template).decode(),
|
||||
'system_message': orjson.dumps(self.system_message).decode(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGE = ''
|
||||
DEFAULT_PROMPT_TEMPLATE = PromptTemplate('{instruction}')
|
||||
|
||||
|
||||
class BaichuanConfig(openllm_core.LLMConfig):
|
||||
@@ -46,32 +41,3 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
max_new_tokens: int = 2048
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
@property
|
||||
def default_prompt_template(self) -> str:
|
||||
return DEFAULT_PROMPT_TEMPLATE.to_string()
|
||||
|
||||
@property
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None:
|
||||
prompt_template = DEFAULT_PROMPT_TEMPLATE
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = '{instruction}'
|
||||
|
||||
|
||||
class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
@@ -43,40 +36,9 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
'thudm/chatglm3-6b',
|
||||
],
|
||||
}
|
||||
retain_history: bool = dantic.Field(
|
||||
False,
|
||||
description='Whether to retain history given to the model. If set to True, then the model will retain given history.',
|
||||
)
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
num_beams: int = 1
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
num_beams: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ''
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history):
|
||||
prompt_text += f'[Round {i}]\n问:{old_query}\n答:{response}\n'
|
||||
prompt_text += f'[Round {len(chat_history)}]\n问:{prompt}\n答:'
|
||||
else:
|
||||
prompt_text = prompt
|
||||
return (
|
||||
prompt_text,
|
||||
{'max_new_tokens': max_new_tokens, 'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature, **attrs},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -2,8 +2,6 @@ from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate, process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
@@ -14,32 +12,9 @@ END_KEY = '### End'
|
||||
INTRO_BLURB = (
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
)
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
DEFAULT_PROMPT_TEMPLATE = """{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
|
||||
|
||||
|
||||
def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
|
||||
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
||||
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
||||
|
||||
Args:
|
||||
tokenizer: the tokenizer
|
||||
key: the key to convert to a single token
|
||||
|
||||
Raises:
|
||||
RuntimeError: if more than one ID was generated
|
||||
|
||||
Returns:
|
||||
int: the token ID for the given key.
|
||||
"""
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
@@ -66,7 +41,6 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
'default_id': 'databricks/dolly-v2-3b',
|
||||
'model_ids': ['databricks/dolly-v2-3b', 'databricks/dolly-v2-7b', 'databricks/dolly-v2-12b'],
|
||||
}
|
||||
return_full_text: bool = dantic.Field(False, description='Whether to return the full prompt to the users.')
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
@@ -75,20 +49,8 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
max_new_tokens: int = 256
|
||||
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{'max_new_tokens': max_new_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **attrs},
|
||||
{},
|
||||
@property
|
||||
def template(self):
|
||||
return '{intro}\n{instruction_key}\n{instruction}\n{response_key}\n'.format(
|
||||
intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY
|
||||
)
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate, process_prompt
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{context}
|
||||
{user_name}: {instruction}
|
||||
{agent}:
|
||||
"""
|
||||
|
||||
|
||||
class FalconConfig(openllm_core.LLMConfig):
|
||||
@@ -46,26 +39,6 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
num_return_sequences: int = 1
|
||||
num_beams: int = 4
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
eos_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences,
|
||||
'eos_token_id': eos_token_id,
|
||||
**attrs,
|
||||
},
|
||||
{},
|
||||
)
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return '{context}\n{user_name}: {instruction}\n{agent}:'
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate, process_prompt
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:"""
|
||||
|
||||
|
||||
class FlanT5Config(openllm_core.LLMConfig):
|
||||
@@ -38,27 +34,6 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
top_p: float = 0.4
|
||||
repetition_penalty = 1.0
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'repetition_penalty': repetition_penalty,
|
||||
},
|
||||
{},
|
||||
)
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return 'Answer the following question:\nQuestion: {instruction}\nAnswer:'
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate, process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
|
||||
class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
@@ -33,24 +28,7 @@ class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
'default_id': 'eleutherai/gpt-neox-20b',
|
||||
'model_ids': ['eleutherai/gpt-neox-20b'],
|
||||
}
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,34 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGE = """
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
||||
"""
|
||||
SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '<<SYS>>', '</s>', '<s>'
|
||||
# TODO: support history and v1 prompt implementation
|
||||
_v1_prompt, _v2_prompt = (
|
||||
"""{instruction}""",
|
||||
"""{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n""".format(
|
||||
start_key=SINST_KEY,
|
||||
sys_key=SYS_KEY,
|
||||
system_message='{system_message}',
|
||||
instruction='{instruction}',
|
||||
end_key=EINST_KEY,
|
||||
),
|
||||
)
|
||||
PROMPT_MAPPING = {'v1': _v1_prompt, 'v2': _v2_prompt}
|
||||
|
||||
|
||||
def _get_prompt(model_type: t.Literal['v1', 'v2']) -> PromptTemplate:
|
||||
return PromptTemplate(PROMPT_MAPPING[model_type])
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
|
||||
|
||||
class LlamaConfig(openllm_core.LLMConfig):
|
||||
@@ -80,31 +54,15 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
presence_penalty: float = 0.5
|
||||
|
||||
@property
|
||||
def default_prompt_template(self) -> str:
|
||||
return DEFAULT_PROMPT_TEMPLATE('v2').to_string()
|
||||
def template(self) -> str:
|
||||
return '{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n'.format(
|
||||
start_key=SINST_KEY,
|
||||
sys_key=SYS_KEY,
|
||||
system_message='{system_message}',
|
||||
instruction='{instruction}',
|
||||
end_key=EINST_KEY,
|
||||
)
|
||||
|
||||
@property
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None:
|
||||
prompt_template = DEFAULT_PROMPT_TEMPLATE('v2')
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k},
|
||||
{},
|
||||
)
|
||||
def system_message(self) -> str:
|
||||
return "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
# https://docs.mistral.ai/usage/guardrailing/
|
||||
DEFAULT_SYSTEM_MESSAGE = """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity."""
|
||||
SINST_KEY, EINST_KEY, BOS_TOKEN, EOS_TOKEN = '[INST]', '[/INST]', '<s>', '</s>'
|
||||
DEFAULT_PROMPT_TEMPLATE = """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format(
|
||||
start_inst=SINST_KEY,
|
||||
end_inst=EINST_KEY,
|
||||
start_key=BOS_TOKEN,
|
||||
system_message='{system_message}',
|
||||
instruction='{instruction}',
|
||||
)
|
||||
|
||||
|
||||
class MistralConfig(openllm_core.LLMConfig):
|
||||
@@ -33,8 +22,7 @@ class MistralConfig(openllm_core.LLMConfig):
|
||||
'default_id': 'mistralai/Mistral-7B-Instruct-v0.1',
|
||||
'serialisation': 'safetensors',
|
||||
'backend': ('pt', 'vllm'),
|
||||
# NOTE: see https://docs.mistral.ai/usage/guardrailing/
|
||||
# and https://docs.mistral.ai/llm/mistral-instruct-v0.1
|
||||
# NOTE: see https://docs.mistral.ai/usage/guardrailing/ and https://docs.mistral.ai/llm/mistral-instruct-v0.1
|
||||
'model_ids': [
|
||||
'HuggingFaceH4/zephyr-7b-alpha',
|
||||
'HuggingFaceH4/zephyr-7b-beta',
|
||||
@@ -57,32 +45,16 @@ class MistralConfig(openllm_core.LLMConfig):
|
||||
presence_penalty: float = 0.5
|
||||
|
||||
@property
|
||||
def default_prompt_template(self) -> str:
|
||||
return DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
@property
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None:
|
||||
prompt_template = DEFAULT_PROMPT_TEMPLATE
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k},
|
||||
{},
|
||||
def template(self) -> str:
|
||||
return """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format(
|
||||
start_inst=SINST_KEY,
|
||||
end_inst=EINST_KEY,
|
||||
start_key=BOS_TOKEN,
|
||||
system_message='{system_message}',
|
||||
instruction='{instruction}',
|
||||
)
|
||||
|
||||
# NOTE: https://docs.mistral.ai/usage/guardrailing/
|
||||
@property
|
||||
def system_message(self) -> str:
|
||||
return """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity."""
|
||||
|
||||
@@ -1,41 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate, process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
MPTPromptType = t.Literal['default', 'instruct', 'chat', 'storywriter']
|
||||
|
||||
INSTRUCTION_KEY, RESPONSE_KEY, END_KEY = '### Instruction:', '### Response:', '### End'
|
||||
INTRO_BLURB = (
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
)
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
_chat_prompt, _default_prompt, _instruct_prompt = (
|
||||
"""{instruction}""",
|
||||
"""{instruction}""",
|
||||
"""{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY),
|
||||
)
|
||||
PROMPT_MAPPING = {
|
||||
'default': _default_prompt,
|
||||
'instruct': _instruct_prompt,
|
||||
'storywriter': _default_prompt,
|
||||
'chat': _chat_prompt,
|
||||
}
|
||||
|
||||
|
||||
def _get_prompt(model_type: str) -> str:
|
||||
return PROMPT_MAPPING[model_type]
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
|
||||
|
||||
class MPTConfig(openllm_core.LLMConfig):
|
||||
@@ -67,46 +32,8 @@ class MPTConfig(openllm_core.LLMConfig):
|
||||
'mosaicml/mpt-30b-chat',
|
||||
],
|
||||
}
|
||||
prompt_type: MPTPromptType = dantic.Field(
|
||||
'"default"',
|
||||
description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.',
|
||||
)
|
||||
max_sequence_length: int = dantic.Field(
|
||||
2048,
|
||||
description='Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)',
|
||||
)
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0
|
||||
top_p: float = 0.8
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
prompt_type: MPTPromptType | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
if 'instruct' in self.model_id:
|
||||
prompt_type = 'instruct'
|
||||
elif 'storywriter' in self.model_id:
|
||||
prompt_type = 'storywriter'
|
||||
elif 'chat' in self.model_id:
|
||||
prompt_type = 'chat'
|
||||
else:
|
||||
prompt_type = 'default'
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return (
|
||||
process_prompt(prompt, _template, use_default_prompt_template),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core.prompts.prompt_template import PromptTemplate
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
|
||||
class OPTConfig(openllm_core.LLMConfig):
|
||||
@@ -52,26 +45,3 @@ class OPTConfig(openllm_core.LLMConfig):
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 256
|
||||
num_return_sequences: int = 1
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences,
|
||||
},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGE = ''
|
||||
|
||||
|
||||
class PhiConfig(openllm_core.LLMConfig):
|
||||
@@ -40,30 +34,3 @@ class PhiConfig(openllm_core.LLMConfig):
|
||||
|
||||
class SamplingParams:
|
||||
best_of: int = 1
|
||||
|
||||
@property
|
||||
def default_prompt_template(self) -> str:
|
||||
return DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
@property
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None:
|
||||
prompt_template = PromptTemplate(template=self.default_prompt_template)
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,16 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate, process_prompt
|
||||
|
||||
SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""
|
||||
|
||||
|
||||
class StableLMConfig(openllm_core.LLMConfig):
|
||||
@@ -47,27 +37,15 @@ class StableLMConfig(openllm_core.LLMConfig):
|
||||
top_k: int = 0
|
||||
top_p: float = 0.9
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if 'tuned' in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop('system_prompt', SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(
|
||||
prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs
|
||||
)
|
||||
else:
|
||||
prompt_text = prompt
|
||||
return (
|
||||
prompt_text,
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p},
|
||||
{},
|
||||
)
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return '{system_message}<|USER|>{instruction}<|ASSISTANT|>'
|
||||
|
||||
@property
|
||||
def system_message(self) -> str:
|
||||
return """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
@@ -1,21 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = (
|
||||
'<fim-prefix>',
|
||||
'<fim-middle>',
|
||||
'<fim-suffix>',
|
||||
'<fim-pad>',
|
||||
'<|endoftext|>',
|
||||
'<FILL_HERE>',
|
||||
)
|
||||
|
||||
|
||||
class StarCoderConfig(openllm_core.LLMConfig):
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
@@ -44,37 +30,3 @@ class StarCoderConfig(openllm_core.LLMConfig):
|
||||
top_p: float = 0.95
|
||||
pad_token_id: int = 49152
|
||||
repetition_penalty: float = 1.2
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try:
|
||||
prefix, suffix = prompt.split(FIM_INDICATOR)
|
||||
except Exception as err:
|
||||
raise ValueError(f'Only one {FIM_INDICATOR} allowed in prompt') from err
|
||||
prompt_text = f'{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}'
|
||||
else:
|
||||
prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the default starcoder doesn't include the same value as santacoder EOD
|
||||
return (
|
||||
prompt_text,
|
||||
{
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'repetition_penalty': repetition_penalty,
|
||||
'pad_token_id': 49152,
|
||||
**attrs,
|
||||
},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGE = ''
|
||||
DEFAULT_PROMPT_TEMPLATE = '{instruction}'
|
||||
|
||||
|
||||
class YiConfig(openllm_core.LLMConfig):
|
||||
@@ -39,35 +34,3 @@ class YiConfig(openllm_core.LLMConfig):
|
||||
class SamplingParams:
|
||||
best_of: int = 1
|
||||
presence_penalty: float = 0.5
|
||||
|
||||
@property
|
||||
def default_prompt_template(self) -> str:
|
||||
return DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
@property
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_new_tokens: int,
|
||||
repetition_penalty: float,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None:
|
||||
prompt_template = DEFAULT_PROMPT_TEMPLATE
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k},
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .prompt_template import PromptTemplate as PromptTemplate, process_prompt as process_prompt
|
||||
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
from .utils import default_formatter
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
_input_variables: t.Sequence[str] = attr.field(init=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._input_variables = default_formatter.extract_template_variables(self.template)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> PromptTemplate:
|
||||
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
|
||||
o = attr.evolve(self, template=self.template.format(**prompt_variables))
|
||||
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
|
||||
return o
|
||||
|
||||
def to_string(self) -> str:
|
||||
return self.template
|
||||
|
||||
def format(self, **attrs: t.Any) -> str:
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
|
||||
|
||||
# TODO: remove process_prompt after refactor config for all models
|
||||
def process_prompt(
|
||||
prompt: str, template: PromptTemplate | str | None = None, use_prompt_template: bool = True, **attrs: t.Any
|
||||
) -> str:
|
||||
# Currently, all default prompt will always have `instruction` key.
|
||||
if not use_prompt_template:
|
||||
return prompt
|
||||
elif template is None:
|
||||
raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
if isinstance(template, PromptTemplate):
|
||||
template = template.to_string()
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if 'instruction' in prompt_variables:
|
||||
raise RuntimeError(
|
||||
"'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'"
|
||||
)
|
||||
try:
|
||||
return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template."
|
||||
) from None
|
||||
@@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
@@ -115,7 +115,7 @@ def calc_dir_size(path:PathType)->int:return sum(f.stat().st_size for f in _Path
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def generate_hash_from_file(f:str,algorithm:t.Literal['md5','sha1']='sha1')->str:return str(getattr(hashlib,algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
|
||||
def getenv(env:str,default:t.Any=None,var:t.Sequence[str]|None=None)->t.Any:
|
||||
env_key={f'OPENLLM_{env.upper()}',env.upper()}
|
||||
env_key={env.upper(),f'OPENLLM_{env.upper()}'}
|
||||
if var is not None:env_key=set(var)|env_key
|
||||
def callback(k:str)->t.Any:
|
||||
_var = os.getenv(k)
|
||||
|
||||
@@ -77,8 +77,6 @@ def Runner(
|
||||
'serialisation': first_not_none(
|
||||
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
|
||||
),
|
||||
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
|
||||
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from openllm_core._typing_compat import (
|
||||
TupleAny,
|
||||
)
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
@@ -129,8 +128,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_max_model_len: int | None
|
||||
_prompt_template: PromptTemplate | None
|
||||
_system_message: str | None
|
||||
|
||||
__llm_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
|
||||
__llm_torch_dtype__: 'torch.dtype' = None
|
||||
@@ -148,8 +145,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
model_id,
|
||||
model_version=None,
|
||||
model_tag=None,
|
||||
prompt_template=None,
|
||||
system_message=None,
|
||||
llm_config=None,
|
||||
backend=None,
|
||||
*args,
|
||||
@@ -192,8 +187,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
serialisation=serialisation,
|
||||
local=_local,
|
||||
max_model_len=max_model_len,
|
||||
prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
|
||||
system_message=system_message,
|
||||
LLM__model_attrs=model_attrs,
|
||||
LLM__tokenizer_attrs=tokenizer_attrs,
|
||||
llm_dtype__=dtype.lower(),
|
||||
@@ -480,16 +473,24 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
|
||||
request_id = gen_random_uuid() if request_id is None else request_id
|
||||
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
|
||||
async for out in self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
|
||||
):
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
delta_outputs = [None] * len(generated.outputs)
|
||||
if generated.finished:
|
||||
break
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
try:
|
||||
generator = self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
|
||||
)
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Failed to start generation task: {err}') from err
|
||||
|
||||
try:
|
||||
async for out in generator:
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
delta_outputs = [None] * len(generated.outputs)
|
||||
if generated.finished:
|
||||
break
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Exception caught during generation: {err}') from err
|
||||
|
||||
@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
|
||||
M,
|
||||
T,
|
||||
)
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.utils.representation import ReprArgs
|
||||
|
||||
from ._quantisation import QuantizationConfig
|
||||
@@ -48,8 +47,6 @@ class LLM(Generic[M, T]):
|
||||
_adapter_map: Optional[AdapterMap]
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_prompt_template: Optional[PromptTemplate]
|
||||
_system_message: Optional[str]
|
||||
|
||||
__llm_dtype__: Dtype = ...
|
||||
__llm_torch_dtype__: Optional[torch.dtype] = ...
|
||||
@@ -74,8 +71,6 @@ class LLM(Generic[M, T]):
|
||||
model_id: str,
|
||||
model_version: Optional[str] = ...,
|
||||
model_tag: Optional[Union[str, Tag]] = ...,
|
||||
prompt_template: Optional[Union[str, PromptTemplate]] = ...,
|
||||
system_message: Optional[str] = ...,
|
||||
llm_config: Optional[LLMConfig] = ...,
|
||||
backend: Optional[LiteralBackend] = ...,
|
||||
*args: Any,
|
||||
|
||||
@@ -27,7 +27,7 @@ def registry(cls=None, *, alias=None):
|
||||
return decorator(cls)
|
||||
|
||||
|
||||
def runner(llm):
|
||||
def runner(llm: openllm.LLM):
|
||||
from ._strategies import CascadingResourceStrategy
|
||||
|
||||
try:
|
||||
@@ -35,19 +35,6 @@ def runner(llm):
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err
|
||||
|
||||
if llm._prompt_template:
|
||||
prompt_template = llm._prompt_template.to_string()
|
||||
elif hasattr(llm.config, 'default_prompt_template'):
|
||||
prompt_template = llm.config.default_prompt_template
|
||||
else:
|
||||
prompt_template = None
|
||||
if llm._system_message:
|
||||
system_message = llm._system_message
|
||||
elif hasattr(llm.config, 'default_system_message'):
|
||||
system_message = llm.config.default_system_message
|
||||
else:
|
||||
system_message = None
|
||||
|
||||
return types.new_class(
|
||||
llm.config.__class__.__name__[:-6] + 'Runner',
|
||||
(bentoml.Runner,),
|
||||
@@ -80,8 +67,8 @@ def runner(llm):
|
||||
('llm_tag', llm.tag),
|
||||
),
|
||||
'has_adapters': llm.has_adapters,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
'template': llm.config.template,
|
||||
'system_message': llm.config.system_message,
|
||||
}
|
||||
),
|
||||
)(
|
||||
@@ -101,7 +88,7 @@ class CTranslateRunnable(bentoml.Runnable):
|
||||
def __init__(self, llm):
|
||||
if not is_ctranslate_available():
|
||||
raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.modle, llm.tokenizer
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
|
||||
@@ -83,8 +83,8 @@ class Runner(Protocol[Mo, To]):
|
||||
config: LLMConfig = ...
|
||||
backend: LiteralBackend = ...
|
||||
has_adapters: bool = ...
|
||||
prompt_template: Optional[str] = ...
|
||||
system_message: Optional[str] = ...
|
||||
template: str = ...
|
||||
system_message: str = ...
|
||||
|
||||
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
|
||||
@staticmethod
|
||||
|
||||
@@ -13,8 +13,6 @@ logger = logging.getLogger(__name__)
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
prompt_template=svars.prompt_template,
|
||||
system_message=svars.system_message,
|
||||
serialisation=svars.serialization,
|
||||
adapter_map=svars.adapter_map,
|
||||
trust_remote_code=svars.trust_remote_code,
|
||||
@@ -50,8 +48,6 @@ _Metadata = openllm.MetadataOutput(
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm.config.model_dump_json().decode(),
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,5 @@ from openllm_core.utils import ENV_VARS_TRUE_VALUES
|
||||
model_id = os.environ['OPENLLM_MODEL_ID']
|
||||
model_tag = None
|
||||
adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
|
||||
prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
|
||||
system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
|
||||
serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
|
||||
trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
@@ -5,7 +5,5 @@ from openllm_core._typing_compat import LiteralSerialisation
|
||||
model_id: str = ...
|
||||
model_tag: Optional[str] = ...
|
||||
adapter_map: Optional[Dict[str, str]] = ...
|
||||
prompt_template: Optional[str] = ...
|
||||
system_message: Optional[str] = ...
|
||||
serialization: LiteralSerialisation = ...
|
||||
trust_remote_code: bool = ...
|
||||
|
||||
@@ -4,6 +4,4 @@ model_id = '{__model_id__}' # openllm: model id
|
||||
model_tag = '{__model_tag__}' # openllm: model tag
|
||||
adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
|
||||
serialization = '{__model_serialization__}' # openllm: model serialization
|
||||
prompt_template = {__model_prompt_template__} # openllm: model prompt template
|
||||
system_message = {__model_system_message__} # openllm: model system message
|
||||
trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code
|
||||
|
||||
@@ -83,8 +83,6 @@ def construct_docker_options(
|
||||
None,
|
||||
llm._serialisation,
|
||||
llm,
|
||||
llm._system_message,
|
||||
llm._prompt_template,
|
||||
use_current_env=False,
|
||||
)
|
||||
# XXX: We need to quote this so that the envvar in container recognize as valid json
|
||||
@@ -101,8 +99,6 @@ def construct_docker_options(
|
||||
OPENLLM_MODEL_ID = '# openllm: model id'
|
||||
OPENLLM_MODEL_TAG = '# openllm: model tag'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
|
||||
OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
|
||||
OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
|
||||
OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'
|
||||
|
||||
@@ -140,16 +136,6 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
|
||||
identifier = OPENLLM_MODEL_ADAPTER_MAP
|
||||
|
||||
|
||||
class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
|
||||
keyword = '__model_prompt_template__'
|
||||
identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class ModelSystemMessageFormatter(_ServiceVarsFormatter):
|
||||
keyword = '__model_system_message__'
|
||||
identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
|
||||
|
||||
|
||||
class ModelSerializationFormatter(_ServiceVarsFormatter):
|
||||
keyword = '__model_serialization__'
|
||||
identifier = OPENLLM_MODEL_SERIALIZATION
|
||||
@@ -187,16 +173,6 @@ def write_service(llm, llm_fs, adapter_map):
|
||||
src_contents[i] = serialization_formatter.parse_line(it)
|
||||
elif trust_remote_code_formatter.identifier in it:
|
||||
src_contents[i] = trust_remote_code_formatter.parse_line(it)
|
||||
elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
|
||||
if llm._prompt_template:
|
||||
src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
|
||||
else:
|
||||
src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
|
||||
elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
|
||||
if llm._system_message:
|
||||
src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
|
||||
else:
|
||||
src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)
|
||||
|
||||
script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if SHOW_CODEGEN:
|
||||
|
||||
@@ -105,7 +105,7 @@ async def cohere_generate(req, llm):
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
request_id = gen_random_uuid('cohere-generate')
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
if request.prompt_vars is not None:
|
||||
prompt = request.prompt.format(**request.prompt_vars)
|
||||
@@ -202,7 +202,7 @@ async def cohere_chat(req, llm):
|
||||
_transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
|
||||
@@ -156,7 +156,7 @@ async def chat_completions(req, llm):
|
||||
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
@@ -272,14 +272,9 @@ async def completions(req, llm):
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
|
||||
return error_response(
|
||||
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
|
||||
)
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('cmpl')
|
||||
created_time = int(time.monotonic())
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
|
||||
@@ -135,13 +135,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
_OpenLLM_GenericInternalConfig().to_click_options,
|
||||
_OpenLLM_GenericInternalConfig.parse,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
|
||||
dtype_option(factory=cog.optgroup),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
@@ -318,27 +316,6 @@ def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
)(f)
|
||||
|
||||
|
||||
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--backend',
|
||||
|
||||
@@ -37,8 +37,6 @@ def _start(
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
@@ -60,8 +58,6 @@ def _start(
|
||||
Args:
|
||||
model_id: The model id to start this LLMServer
|
||||
timeout: The server timeout
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
workers_per_resource: Number of workers per resource assigned.
|
||||
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -88,10 +84,6 @@ def _start(
|
||||
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
|
||||
args: list[str] = [model_id]
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
@@ -129,8 +121,6 @@ def _build(
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
@@ -155,8 +145,6 @@ def _build(
|
||||
model_id: The model id to build this BentoLLM
|
||||
model_version: Optional model version for this given LLM
|
||||
bento_version: Optional bento veresion for this given BentoLLM
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
@@ -209,10 +197,6 @@ def _build(
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
|
||||
@@ -100,11 +100,9 @@ from ._factory import (
|
||||
model_name_argument,
|
||||
model_version_option,
|
||||
parse_config_options,
|
||||
prompt_template_file_option,
|
||||
quantize_option,
|
||||
serialisation_option,
|
||||
start_decorator,
|
||||
system_message_option,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -404,8 +402,6 @@ def start_command(
|
||||
model_id: str,
|
||||
server_timeout: int,
|
||||
model_version: str | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
@@ -437,7 +433,6 @@ def start_command(
|
||||
)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
@@ -467,8 +462,6 @@ def start_command(
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
@@ -495,8 +488,6 @@ def start_command(
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
system_message,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
server = bentoml.HTTPServer('_service:svc', **server_attrs)
|
||||
@@ -541,8 +532,6 @@ def start_grpc_command(
|
||||
model_id: str,
|
||||
server_timeout: int,
|
||||
model_version: str | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
@@ -577,7 +566,6 @@ def start_grpc_command(
|
||||
)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
@@ -604,8 +592,6 @@ def start_grpc_command(
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
@@ -634,8 +620,6 @@ def start_grpc_command(
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
system_message,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs)
|
||||
@@ -654,18 +638,7 @@ def start_grpc_command(
|
||||
|
||||
|
||||
def process_environ(
|
||||
config,
|
||||
server_timeout,
|
||||
wpr,
|
||||
device,
|
||||
cors,
|
||||
model_id,
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
system_message,
|
||||
prompt_template,
|
||||
use_current_env=True,
|
||||
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True
|
||||
) -> t.Dict[str, t.Any]:
|
||||
environ = parse_config_options(
|
||||
config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
|
||||
@@ -685,10 +658,6 @@ def process_environ(
|
||||
)
|
||||
if llm.quantise:
|
||||
environ['QUANTIZE'] = str(llm.quantise)
|
||||
if system_message:
|
||||
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template:
|
||||
environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
return environ
|
||||
|
||||
|
||||
@@ -929,8 +898,6 @@ class BuildBentoOutput(t.TypedDict):
|
||||
)
|
||||
@dtype_option
|
||||
@backend_option
|
||||
@system_message_option
|
||||
@prompt_template_file_option
|
||||
@click.option(
|
||||
'--bento-version',
|
||||
type=str,
|
||||
@@ -1004,8 +971,6 @@ def build_command(
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
backend: LiteralBackend | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
model_version: str | None,
|
||||
dockerfile_template: t.TextIO | None,
|
||||
containerize: bool,
|
||||
@@ -1051,12 +1016,9 @@ def build_command(
|
||||
|
||||
state = ItemState.NOT_FOUND
|
||||
|
||||
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
@@ -1075,19 +1037,7 @@ def build_command(
|
||||
llm._tag = model.tag
|
||||
|
||||
os.environ.update(
|
||||
**process_environ(
|
||||
llm.config,
|
||||
llm.config['timeout'],
|
||||
1.0,
|
||||
None,
|
||||
True,
|
||||
llm.model_id,
|
||||
None,
|
||||
llm._serialisation,
|
||||
llm,
|
||||
llm._system_message,
|
||||
llm._prompt_template,
|
||||
)
|
||||
**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm)
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,33 +1,82 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import traceback
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml_cli.utils import opt_callback
|
||||
|
||||
import openllm
|
||||
import openllm_core
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
_input_variables: t.Sequence[str] = attr.field(init=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._input_variables = default_formatter.extract_template_variables(self.template)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> PromptTemplate:
|
||||
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
|
||||
o = attr.evolve(self, template=self.template.format(**prompt_variables))
|
||||
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
|
||||
return o
|
||||
|
||||
def format(self, **attrs: t.Any) -> str:
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
|
||||
shell_complete=model_complete_envvar,
|
||||
)
|
||||
@click.argument('model_id', shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@click.option('--prompt-template-file', type=click.File(), default=None)
|
||||
@click.option('--chat-template-file', type=click.File(), default=None)
|
||||
@click.option('--system-message', type=str, default=None)
|
||||
@click.option(
|
||||
'--add-generation-prompt/--no-add-generation-prompt',
|
||||
default=False,
|
||||
help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This only applicable if model-id is a HF model_id',
|
||||
)
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
@@ -35,45 +84,96 @@ logger = logging.getLogger(__name__)
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(
|
||||
ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
chat_template_file: t.IO[t.Any] | None,
|
||||
system_message: str | None,
|
||||
add_generation_prompt: bool,
|
||||
_memoized: dict[str, t.Any],
|
||||
**_: t.Any,
|
||||
) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = getattr(openllm_core.config, f'configuration_{model_name}')
|
||||
"""Helpers for generating prompts.
|
||||
|
||||
\b
|
||||
It accepts remote HF model_ids as well as model name passed to `openllm start`.
|
||||
|
||||
If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.
|
||||
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
|
||||
```
|
||||
|
||||
If you need change the prompt template, you can create the template file that contains the jina2 template through `--chat-template-file`
|
||||
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
|
||||
```
|
||||
|
||||
\b
|
||||
|
||||
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
|
||||
Note that this is mainly for utilities, as OpenLLM won't use these prompts to format for you.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt mistral "Hello there"
|
||||
"""
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
|
||||
prompt_mapping = getattr(module, 'PROMPT_MAPPING', None)
|
||||
if template is None:
|
||||
raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
|
||||
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage(
|
||||
'format',
|
||||
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})",
|
||||
)
|
||||
if prompt_mapping is None:
|
||||
raise click.BadArgumentUsage(
|
||||
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.'
|
||||
) from None
|
||||
if format not in prompt_mapping:
|
||||
raise click.BadOptionUsage(
|
||||
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})'
|
||||
)
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
|
||||
if prompt_template_file and chat_template_file:
|
||||
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
|
||||
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
|
||||
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
|
||||
)
|
||||
if model_id in acceptable:
|
||||
logger.warning(
|
||||
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
|
||||
)
|
||||
config = openllm.AutoConfig.for_model(model_id)
|
||||
template = prompt_template_file.read() if prompt_template_file is not None else config.template
|
||||
system_message = system_message or config.system_message
|
||||
|
||||
try:
|
||||
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
formatted = (
|
||||
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(
|
||||
prompt, prompt_template=_prompt_template
|
||||
)[0]
|
||||
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from err
|
||||
ctx.fail(str(err))
|
||||
else:
|
||||
import transformers
|
||||
|
||||
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
if chat_template_file is not None:
|
||||
chat_template_file = chat_template_file.read()
|
||||
if system_message is None:
|
||||
logger.warning('system-message is not provided, using default infer from the model architecture.\n')
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
system_message = (
|
||||
openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
)
|
||||
.model_construct_env()
|
||||
.system_message
|
||||
)
|
||||
break
|
||||
else:
|
||||
ctx.fail(
|
||||
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
|
||||
)
|
||||
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
|
||||
)
|
||||
|
||||
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -171,7 +171,7 @@ def test_click_conversion(gen_settings: ModelSettings):
|
||||
return attrs
|
||||
|
||||
cl_ = make_llm_config('ClickConversionLLM', gen_settings)
|
||||
wrapped = cl_.to_click_options(cli_mock)
|
||||
wrapped = cl_.parse(cli_mock)
|
||||
filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
|
||||
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith('fake_')]
|
||||
assert len(filtered) == len(click_options_filtered)
|
||||
|
||||
Reference in New Issue
Block a user