Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-19 07:06:02 -05:00)
chore: cleanup unused prompt templates (#713)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -77,8 +77,6 @@ def Runner(
     'serialisation': first_not_none(
       attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
     ),
-    'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
-    'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
   }
 )

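The surviving 'serialisation' entry resolves its value from the first non-None source in a fixed order: the explicit keyword argument, then the OPENLLM_SERIALIZATION environment variable, then the config default. A minimal, self-contained sketch of that fallback pattern, with a local first_not_none stand-in rather than the helper from openllm_core.utils:

    import os

    def first_not_none(*values, default=None):
      # Return the first argument that is not None, else fall back to `default`.
      return next((v for v in values if v is not None), default)

    # Hypothetical inputs standing in for the kwargs and config seen in the hunk above.
    attrs = {'serialisation': None}
    llm_config = {'serialisation': 'safetensors'}

    serialisation = first_not_none(
      attrs.get('serialisation'),               # explicit keyword argument wins
      os.environ.get('OPENLLM_SERIALIZATION'),  # then the environment variable
      default=llm_config['serialisation'],      # finally the config default
    )
    print(serialisation)  # 'safetensors' unless OPENLLM_SERIALIZATION is set

The two deleted entries followed the same pattern for OPENLLM_SYSTEM_MESSAGE and OPENLLM_PROMPT_TEMPLATE; with those fields gone, these environment variables are no longer consulted here.
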
@@ -25,7 +25,6 @@ from openllm_core._typing_compat import (
   TupleAny,
 )
 from openllm_core.exceptions import MissingDependencyError
-from openllm_core.prompts import PromptTemplate
 from openllm_core.utils import (
   DEBUG,
   ENV_VARS_TRUE_VALUES,

@@ -129,8 +128,6 @@ class LLM(t.Generic[M, T], ReprMixin):
   _serialisation: LiteralSerialisation
   _local: bool
   _max_model_len: int | None
-  _prompt_template: PromptTemplate | None
-  _system_message: str | None

   __llm_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
   __llm_torch_dtype__: 'torch.dtype' = None

@@ -148,8 +145,6 @@ class LLM(t.Generic[M, T], ReprMixin):
     model_id,
     model_version=None,
     model_tag=None,
-    prompt_template=None,
-    system_message=None,
     llm_config=None,
     backend=None,
     *args,

@@ -192,8 +187,6 @@ class LLM(t.Generic[M, T], ReprMixin):
       serialisation=serialisation,
       local=_local,
       max_model_len=max_model_len,
-      prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
-      system_message=system_message,
       LLM__model_attrs=model_attrs,
       LLM__tokenizer_attrs=tokenizer_attrs,
       llm_dtype__=dtype.lower(),

@@ -480,16 +473,24 @@ class LLM(t.Generic[M, T], ReprMixin):

     request_id = gen_random_uuid() if request_id is None else request_id
     previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
-    async for out in self.runner.generate_iterator.async_stream(
-      prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
-    ):
-      generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
-      delta_outputs = [None] * len(generated.outputs)
-      if generated.finished:
-        break
-      for output in generated.outputs:
-        i = output.index
-        delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
-        previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
-        delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
-      yield generated.with_options(outputs=delta_outputs)
+    try:
+      generator = self.runner.generate_iterator.async_stream(
+        prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
+      )
+    except Exception as err:
+      raise RuntimeError(f'Failed to start generation task: {err}') from err
+
+    try:
+      async for out in generator:
+        generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
+        delta_outputs = [None] * len(generated.outputs)
+        if generated.finished:
+          break
+        for output in generated.outputs:
+          i = output.index
+          delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
+          previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
+          delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
+        yield generated.with_options(outputs=delta_outputs)
+    except Exception as err:
+      raise RuntimeError(f'Exception caught during generation: {err}') from err

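Both the removed and the added loop compute the same per-step deltas: each runner update reports cumulative text and token ids per output index, and only the new suffix since the previous update is yielded. The rewrite changes only the error handling: creating the stream and consuming it are wrapped in separate try/except blocks so that startup failures and mid-stream failures surface as distinct RuntimeErrors. A self-contained sketch of the delta slicing, using plain dicts in place of GenerationOutput (assumed here, not imported):

    import asyncio

    async def fake_runner_stream():
      # Cumulative outputs, as a runner would report them step by step.
      yield [{'index': 0, 'text': 'Hello', 'token_ids': [15496]}]
      yield [{'index': 0, 'text': 'Hello world', 'token_ids': [15496, 995]}]

    async def stream_deltas(n=1):
      previous_texts, previous_num_tokens = [''] * n, [0] * n
      async for outputs in fake_runner_stream():
        deltas = [None] * len(outputs)
        for output in outputs:
          i = output['index']
          # Slice off only what was generated since the last update.
          delta_text = output['text'][len(previous_texts[i]):]
          delta_tokens = output['token_ids'][previous_num_tokens[i]:]
          previous_texts[i], previous_num_tokens[i] = output['text'], len(output['token_ids'])
          deltas[i] = {'index': i, 'text': delta_text, 'token_ids': delta_tokens}
        yield deltas

    async def main():
      async for deltas in stream_deltas():
        print(deltas)  # first {'text': 'Hello', ...}, then {'text': ' world', ...}

    asyncio.run(main())
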
@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
   M,
   T,
 )
-from openllm_core.prompts import PromptTemplate
 from openllm_core.utils.representation import ReprArgs

 from ._quantisation import QuantizationConfig

@@ -48,8 +47,6 @@ class LLM(Generic[M, T]):
   _adapter_map: Optional[AdapterMap]
   _serialisation: LiteralSerialisation
   _local: bool
-  _prompt_template: Optional[PromptTemplate]
-  _system_message: Optional[str]

   __llm_dtype__: Dtype = ...
   __llm_torch_dtype__: Optional[torch.dtype] = ...

@@ -74,8 +71,6 @@ class LLM(Generic[M, T]):
     model_id: str,
     model_version: Optional[str] = ...,
     model_tag: Optional[Union[str, Tag]] = ...,
-    prompt_template: Optional[Union[str, PromptTemplate]] = ...,
-    system_message: Optional[str] = ...,
     llm_config: Optional[LLMConfig] = ...,
     backend: Optional[LiteralBackend] = ...,
     *args: Any,

@@ -27,7 +27,7 @@ def registry(cls=None, *, alias=None):
   return decorator(cls)


-def runner(llm):
+def runner(llm: openllm.LLM):
   from ._strategies import CascadingResourceStrategy

   try:

@@ -35,19 +35,6 @@ def runner(llm):
   except bentoml.exceptions.NotFound as err:
     raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err

-  if llm._prompt_template:
-    prompt_template = llm._prompt_template.to_string()
-  elif hasattr(llm.config, 'default_prompt_template'):
-    prompt_template = llm.config.default_prompt_template
-  else:
-    prompt_template = None
-  if llm._system_message:
-    system_message = llm._system_message
-  elif hasattr(llm.config, 'default_system_message'):
-    system_message = llm.config.default_system_message
-  else:
-    system_message = None
-
   return types.new_class(
     llm.config.__class__.__name__[:-6] + 'Runner',
     (bentoml.Runner,),

@@ -80,8 +67,8 @@ def runner(llm):
           ('llm_tag', llm.tag),
         ),
         'has_adapters': llm.has_adapters,
-        'prompt_template': prompt_template,
-        'system_message': system_message,
+        'template': llm.config.template,
+        'system_message': llm.config.system_message,
       }
     ),
   )(

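The runner class itself is produced with types.new_class: the name is derived from the config class by dropping the trailing 'Config', the base is bentoml.Runner, and class attributes such as 'template' and 'system_message' (now read from llm.config rather than the deleted local fallback chain) appear to be injected into the class body. A generic illustration of that stdlib mechanism, with made-up attribute values rather than OpenLLM's:

    import types

    def make_runner_class(config_cls_name, base=object, **attributes):
      # 'FlanT5Config' -> 'FlanT5Runner', attaching the given attributes to the class body.
      name = config_cls_name[:-6] + 'Runner'
      return types.new_class(name, (base,), exec_body=lambda ns: ns.update(attributes))

    # Hypothetical values; the real runner pulls these from llm.config.
    RunnerCls = make_runner_class(
      'FlanT5Config',
      template='{system_message}\n{instruction}',
      system_message='You are a helpful assistant.',
    )
    print(RunnerCls.__name__, RunnerCls.system_message)  # FlanT5Runner You are a helpful assistant.
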
@@ -101,7 +88,7 @@ class CTranslateRunnable(bentoml.Runnable):
   def __init__(self, llm):
     if not is_ctranslate_available():
       raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
-    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.modle, llm.tokenizer
+    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer

   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):

@@ -83,8 +83,8 @@ class Runner(Protocol[Mo, To]):
   config: LLMConfig = ...
   backend: LiteralBackend = ...
   has_adapters: bool = ...
-  prompt_template: Optional[str] = ...
-  system_message: Optional[str] = ...
+  template: str = ...
+  system_message: str = ...

   class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
     @staticmethod

@@ -13,8 +13,6 @@ logger = logging.getLogger(__name__)
 llm = openllm.LLM[t.Any, t.Any](
   model_id=svars.model_id,
   model_tag=svars.model_tag,
-  prompt_template=svars.prompt_template,
-  system_message=svars.system_message,
   serialisation=svars.serialization,
   adapter_map=svars.adapter_map,
   trust_remote_code=svars.trust_remote_code,

@@ -50,8 +48,6 @@ _Metadata = openllm.MetadataOutput(
   backend=llm.__llm_backend__,
   model_id=llm.model_id,
   configuration=llm.config.model_dump_json().decode(),
-  prompt_template=llm.runner.prompt_template,
-  system_message=llm.runner.system_message,
 )

@@ -7,7 +7,5 @@ from openllm_core.utils import ENV_VARS_TRUE_VALUES
 model_id = os.environ['OPENLLM_MODEL_ID']
 model_tag = None
 adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
-prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
-system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
 serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
 trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES

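This generated service module resolves all of its variables from environment variables at import time; with the two deletions, OPENLLM_PROMPT_TEMPLATE and OPENLLM_SYSTEM_MESSAGE are no longer read. A small sketch of the same parsing with a local stand-in for ENV_VARS_TRUE_VALUES (the canonical set lives in openllm_core.utils; the exact members here are an assumption):

    import os
    import orjson

    ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}  # assumed stand-in

    # e.g. OPENLLM_ADAPTER_MAP='{"my-lora": "/path/to/adapter"}' (hypothetical value)
    adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
    serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
    trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES

    print(adapter_map, serialization, trust_remote_code)  # None safetensors False (when nothing is set)
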
@@ -5,7 +5,5 @@ from openllm_core._typing_compat import LiteralSerialisation
 model_id: str = ...
 model_tag: Optional[str] = ...
 adapter_map: Optional[Dict[str, str]] = ...
-prompt_template: Optional[str] = ...
-system_message: Optional[str] = ...
 serialization: LiteralSerialisation = ...
 trust_remote_code: bool = ...

@@ -4,6 +4,4 @@ model_id = '{__model_id__}' # openllm: model id
 model_tag = '{__model_tag__}' # openllm: model tag
 adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
 serialization = '{__model_serialization__}' # openllm: model serialization
-prompt_template = {__model_prompt_template__} # openllm: model prompt template
-system_message = {__model_system_message__} # openllm: model system message
 trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code

@@ -83,8 +83,6 @@ def construct_docker_options(
     None,
     llm._serialisation,
     llm,
-    llm._system_message,
-    llm._prompt_template,
     use_current_env=False,
   )
   # XXX: We need to quote this so that the envvar in container recognize as valid json

@@ -101,8 +99,6 @@ def construct_docker_options(
 OPENLLM_MODEL_ID = '# openllm: model id'
 OPENLLM_MODEL_TAG = '# openllm: model tag'
 OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
-OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
-OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
 OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
 OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'

@@ -140,16 +136,6 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
   identifier = OPENLLM_MODEL_ADAPTER_MAP


-class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
-  keyword = '__model_prompt_template__'
-  identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
-
-
-class ModelSystemMessageFormatter(_ServiceVarsFormatter):
-  keyword = '__model_system_message__'
-  identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
-
-
 class ModelSerializationFormatter(_ServiceVarsFormatter):
   keyword = '__model_serialization__'
   identifier = OPENLLM_MODEL_SERIALIZATION

@@ -187,16 +173,6 @@ def write_service(llm, llm_fs, adapter_map):
       src_contents[i] = serialization_formatter.parse_line(it)
     elif trust_remote_code_formatter.identifier in it:
       src_contents[i] = trust_remote_code_formatter.parse_line(it)
-    elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
-      if llm._prompt_template:
-        src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
-      else:
-        src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
-    elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
-      if llm._system_message:
-        src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
-      else:
-        src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)

   script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
   if SHOW_CODEGEN:

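write_service specialises the generated service template (the {__model_*__} placeholders shown above) by scanning each source line for its marker comment and substituting a concrete value; with these branches gone, the prompt-template and system-message placeholders are no longer rewritten. A rough, hypothetical sketch of that substitution idea (the formatter name and behaviour are guessed from the diff, not copied from OpenLLM):

    # Hypothetical stand-in for the _ServiceVarsFormatter classes referenced in the diff.
    class LineFormatter:
      def __init__(self, keyword, identifier, value):
        self.keyword, self.identifier, self.value = keyword, identifier, value

      def parse_line(self, line):
        # Substitute the {__keyword__} placeholder on lines tagged with the identifier comment.
        return line.replace('{%s}' % self.keyword, self.value) if self.identifier in line else line

    OPENLLM_MODEL_ID = '# openllm: model id'
    formatter = LineFormatter('__model_id__', OPENLLM_MODEL_ID, 'facebook/opt-125m')

    src_contents = ["model_id = '{__model_id__}' # openllm: model id\n"]
    src_contents = [formatter.parse_line(line) for line in src_contents]
    print(''.join(src_contents), end='')  # model_id = 'facebook/opt-125m' # openllm: model id
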
@@ -105,7 +105,7 @@ async def cohere_generate(req, llm):
   if err_check is not None:
     return err_check
   request_id = gen_random_uuid('cohere-generate')
-  config = llm.config.with_request(request)
+  config = llm.config.compatible_options(request)

   if request.prompt_vars is not None:
     prompt = request.prompt.format(**request.prompt_vars)

@@ -202,7 +202,7 @@ async def cohere_chat(req, llm):
     _transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
   )
   logger.debug('Prompt: %r', prompt)
-  config = llm.config.with_request(request)
+  config = llm.config.compatible_options(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)

@@ -156,7 +156,7 @@ async def chat_completions(req, llm):
     request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
   )
   logger.debug('Prompt: %r', prompt)
-  config = llm.config.with_request(request)
+  config = llm.config.compatible_options(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)

@@ -272,14 +272,9 @@ async def completions(req, llm):
   prompt = request.prompt
   # TODO: Support multiple prompts

-  if request.logprobs is not None and llm.__llm_backend__ == 'pt':  # TODO: support logprobs generation for PyTorch
-    return error_response(
-      HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
-    )
-
   model_name, request_id = request.model, gen_random_uuid('cmpl')
   created_time = int(time.monotonic())
-  config = llm.config.with_request(request)
+  config = llm.config.compatible_options(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)

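Across all four handlers, llm.config.with_request(request) is replaced by llm.config.compatible_options(request) before the resulting options are splatted into generate_iterator. The exact semantics of compatible_options live in openllm_core and are not shown in this diff; a plausible reading, sketched with plain dicts and hypothetical names, is that only the request fields corresponding to known generation options are kept:

    from dataclasses import dataclass, asdict
    from typing import Optional

    # Hypothetical request shape; the real handlers receive typed request objects.
    @dataclass
    class CompletionRequest:
      prompt: str
      temperature: Optional[float] = None
      top_p: Optional[float] = None
      user: Optional[str] = None  # not a generation option

    SUPPORTED_OPTIONS = {'temperature', 'top_p', 'max_new_tokens'}  # assumed set

    def compatible_options(request):
      # Keep only the non-None fields that map onto known generation options.
      return {k: v for k, v in asdict(request).items() if k in SUPPORTED_OPTIONS and v is not None}

    config = compatible_options(CompletionRequest(prompt='Hi there', temperature=0.7))
    print(config)  # {'temperature': 0.7}
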