chore: cleanup unused prompt templates (#713)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Authored by Aaron Pham on 2023-11-21 01:56:51 -05:00, committed by GitHub
parent e6b9a749a4
commit fde78a2c78
39 changed files with 300 additions and 923 deletions

View File

@@ -77,8 +77,6 @@ def Runner(
'serialisation': first_not_none(
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
),
- 'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
- 'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
}
)
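Note: the runner options now resolve only the keys still consumed downstream; 'serialisation' keeps its kwarg -> environment -> config-default order. A minimal sketch of that resolution, assuming first_not_none simply returns the first non-None argument (falling back to default):

def first_not_none(*args, default=None):
  # return the first argument that is not None, otherwise the default
  return next((arg for arg in args if arg is not None), default)

# resolution order for 'serialisation' (attrs and llm_config come from the enclosing Runner()):
serialisation = first_not_none(
  attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
)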

View File

@@ -25,7 +25,6 @@ from openllm_core._typing_compat import (
TupleAny,
)
from openllm_core.exceptions import MissingDependencyError
- from openllm_core.prompts import PromptTemplate
from openllm_core.utils import (
DEBUG,
ENV_VARS_TRUE_VALUES,
@@ -129,8 +128,6 @@ class LLM(t.Generic[M, T], ReprMixin):
_serialisation: LiteralSerialisation
_local: bool
_max_model_len: int | None
- _prompt_template: PromptTemplate | None
- _system_message: str | None
__llm_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
__llm_torch_dtype__: 'torch.dtype' = None
@@ -148,8 +145,6 @@ class LLM(t.Generic[M, T], ReprMixin):
model_id,
model_version=None,
model_tag=None,
- prompt_template=None,
- system_message=None,
llm_config=None,
backend=None,
*args,
@@ -192,8 +187,6 @@ class LLM(t.Generic[M, T], ReprMixin):
serialisation=serialisation,
local=_local,
max_model_len=max_model_len,
- prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
- system_message=system_message,
LLM__model_attrs=model_attrs,
LLM__tokenizer_attrs=tokenizer_attrs,
llm_dtype__=dtype.lower(),
@@ -480,16 +473,24 @@ class LLM(t.Generic[M, T], ReprMixin):
request_id = gen_random_uuid() if request_id is None else request_id
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
- async for out in self.runner.generate_iterator.async_stream(
- prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
- ):
- generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
- delta_outputs = [None] * len(generated.outputs)
- if generated.finished:
- break
- for output in generated.outputs:
- i = output.index
- delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
- previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
- delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
- yield generated.with_options(outputs=delta_outputs)
+ try:
+ generator = self.runner.generate_iterator.async_stream(
+ prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
+ )
+ except Exception as err:
+ raise RuntimeError(f'Failed to start generation task: {err}') from err
+ try:
+ async for out in generator:
+ generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
+ delta_outputs = [None] * len(generated.outputs)
+ if generated.finished:
+ break
+ for output in generated.outputs:
+ i = output.index
+ delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
+ previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
+ delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
+ yield generated.with_options(outputs=delta_outputs)
+ except Exception as err:
+ raise RuntimeError(f'Exception caught during generation: {err}') from err
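Note: the rewrite separates failures while starting the runner stream from failures raised during iteration; the per-index delta bookkeeping is unchanged. A hedged sketch of how a caller could reassemble full texts from the yielded deltas (collect_stream is illustrative, not part of this change):

async def collect_stream(llm, prompt, **config):
  # each yielded chunk carries delta text/token_ids per sampled sequence index
  texts = {}
  async for chunk in llm.generate_iterator(prompt, **config):
    for output in chunk.outputs:
      texts[output.index] = texts.get(output.index, '') + output.text
  return texts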

View File

@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
M,
T,
)
- from openllm_core.prompts import PromptTemplate
from openllm_core.utils.representation import ReprArgs
from ._quantisation import QuantizationConfig
@@ -48,8 +47,6 @@ class LLM(Generic[M, T]):
_adapter_map: Optional[AdapterMap]
_serialisation: LiteralSerialisation
_local: bool
- _prompt_template: Optional[PromptTemplate]
- _system_message: Optional[str]
__llm_dtype__: Dtype = ...
__llm_torch_dtype__: Optional[torch.dtype] = ...
@@ -74,8 +71,6 @@ class LLM(Generic[M, T]):
model_id: str,
model_version: Optional[str] = ...,
model_tag: Optional[Union[str, Tag]] = ...,
- prompt_template: Optional[Union[str, PromptTemplate]] = ...,
- system_message: Optional[str] = ...,
llm_config: Optional[LLMConfig] = ...,
backend: Optional[LiteralBackend] = ...,
*args: Any,

View File

@@ -27,7 +27,7 @@ def registry(cls=None, *, alias=None):
return decorator(cls)
- def runner(llm):
+ def runner(llm: openllm.LLM):
from ._strategies import CascadingResourceStrategy
try:
@@ -35,19 +35,6 @@ def runner(llm):
except bentoml.exceptions.NotFound as err:
raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err
- if llm._prompt_template:
- prompt_template = llm._prompt_template.to_string()
- elif hasattr(llm.config, 'default_prompt_template'):
- prompt_template = llm.config.default_prompt_template
- else:
- prompt_template = None
- if llm._system_message:
- system_message = llm._system_message
- elif hasattr(llm.config, 'default_system_message'):
- system_message = llm.config.default_system_message
- else:
- system_message = None
return types.new_class(
llm.config.__class__.__name__[:-6] + 'Runner',
(bentoml.Runner,),
@@ -80,8 +67,8 @@ def runner(llm):
('llm_tag', llm.tag),
),
'has_adapters': llm.has_adapters,
- 'prompt_template': prompt_template,
- 'system_message': system_message,
+ 'template': llm.config.template,
+ 'system_message': llm.config.system_message,
}
),
)(
@@ -101,7 +88,7 @@ class CTranslateRunnable(bentoml.Runnable):
def __init__(self, llm):
if not is_ctranslate_available():
raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
- self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.modle, llm.tokenizer
+ self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
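Note: in this runner factory the removed fallback chain (per-LLM override, then config default_prompt_template/default_system_message, then None) collapses to reading the config directly, roughly:

# the runner attributes now mirror the config verbatim
template = llm.config.template
system_message = llm.config.system_message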

View File

@@ -83,8 +83,8 @@ class Runner(Protocol[Mo, To]):
config: LLMConfig = ...
backend: LiteralBackend = ...
has_adapters: bool = ...
- prompt_template: Optional[str] = ...
- system_message: Optional[str] = ...
+ template: str = ...
+ system_message: str = ...
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
@staticmethod

View File

@@ -13,8 +13,6 @@ logger = logging.getLogger(__name__)
llm = openllm.LLM[t.Any, t.Any](
model_id=svars.model_id,
model_tag=svars.model_tag,
- prompt_template=svars.prompt_template,
- system_message=svars.system_message,
serialisation=svars.serialization,
adapter_map=svars.adapter_map,
trust_remote_code=svars.trust_remote_code,
@@ -50,8 +48,6 @@ _Metadata = openllm.MetadataOutput(
backend=llm.__llm_backend__,
model_id=llm.model_id,
configuration=llm.config.model_dump_json().decode(),
- prompt_template=llm.runner.prompt_template,
- system_message=llm.runner.system_message,
)

View File

@@ -7,7 +7,5 @@ from openllm_core.utils import ENV_VARS_TRUE_VALUES
model_id = os.environ['OPENLLM_MODEL_ID']
model_tag = None
adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
- prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
- system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES

View File

@@ -5,7 +5,5 @@ from openllm_core._typing_compat import LiteralSerialisation
model_id: str = ...
model_tag: Optional[str] = ...
adapter_map: Optional[Dict[str, str]] = ...
- prompt_template: Optional[str] = ...
- system_message: Optional[str] = ...
serialization: LiteralSerialisation = ...
trust_remote_code: bool = ...

View File

@@ -4,6 +4,4 @@ model_id = '{__model_id__}' # openllm: model id
model_tag = '{__model_tag__}' # openllm: model tag
adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
serialization = '{__model_serialization__}' # openllm: model serialization
- prompt_template = {__model_prompt_template__} # openllm: model prompt template
- system_message = {__model_system_message__} # openllm: model system message
trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code
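Note: these placeholders are filled in by the formatters during openllm build. For illustration only, a rendered _service_vars.py could look like this (hypothetical values, imports elided):

model_id = 'HuggingFaceH4/zephyr-7b-beta' # openllm: model id
model_tag = 'zephyr-7b-beta:latest' # openllm: model tag
adapter_map = orjson.loads("""null""") # openllm: model adapter map
serialization = 'safetensors' # openllm: model serialization
trust_remote_code = False # openllm: model trust remote code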

View File

@@ -83,8 +83,6 @@ def construct_docker_options(
None,
llm._serialisation,
llm,
- llm._system_message,
- llm._prompt_template,
use_current_env=False,
)
# XXX: We need to quote this so that the envvar in container recognize as valid json
@@ -101,8 +99,6 @@ def construct_docker_options(
OPENLLM_MODEL_ID = '# openllm: model id'
OPENLLM_MODEL_TAG = '# openllm: model tag'
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
- OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
- OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'
@@ -140,16 +136,6 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
identifier = OPENLLM_MODEL_ADAPTER_MAP
- class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
- keyword = '__model_prompt_template__'
- identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
- class ModelSystemMessageFormatter(_ServiceVarsFormatter):
- keyword = '__model_system_message__'
- identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
class ModelSerializationFormatter(_ServiceVarsFormatter):
keyword = '__model_serialization__'
identifier = OPENLLM_MODEL_SERIALIZATION
@@ -187,16 +173,6 @@ def write_service(llm, llm_fs, adapter_map):
src_contents[i] = serialization_formatter.parse_line(it)
elif trust_remote_code_formatter.identifier in it:
src_contents[i] = trust_remote_code_formatter.parse_line(it)
- elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
- if llm._prompt_template:
- src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
- else:
- src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
- elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
- if llm._system_message:
- src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
- else:
- src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)
script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
if SHOW_CODEGEN:
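Note: each remaining formatter pairs a keyword placeholder with an identifier comment that marks the line to rewrite. A minimal sketch of that mechanism, assuming parse_line substitutes the braced keyword with the value the formatter was constructed with:

class _ExampleFormatter:  # illustrative stand-in for _ServiceVarsFormatter
  keyword = '__model_serialization__'
  identifier = OPENLLM_MODEL_SERIALIZATION
  def __init__(self, value):
    self.value = value
  def parse_line(self, line):
    # replace '{__model_serialization__}' on the line tagged with the identifier comment
    return line.replace('{%s}' % self.keyword, self.value)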

View File

@@ -105,7 +105,7 @@ async def cohere_generate(req, llm):
if err_check is not None:
return err_check
request_id = gen_random_uuid('cohere-generate')
- config = llm.config.with_request(request)
+ config = llm.config.compatible_options(request)
if request.prompt_vars is not None:
prompt = request.prompt.format(**request.prompt_vars)
@@ -202,7 +202,7 @@ async def cohere_chat(req, llm):
_transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
logger.debug('Prompt: %r', prompt)
- config = llm.config.with_request(request)
+ config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
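Note: both chat handlers now render the prompt from the model's chat template instead of the removed prompt-template plumbing, presumably via the Hugging Face tokenizer API; a hedged sketch of that call:

messages = [
  {'role': 'system', 'content': 'You are a helpful assistant.'},
  {'role': 'user', 'content': 'Hello!'},
]
# render the messages into a single prompt string using the tokenizer's chat template
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)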

View File

@@ -156,7 +156,7 @@ async def chat_completions(req, llm):
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
logger.debug('Prompt: %r', prompt)
- config = llm.config.with_request(request)
+ config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -272,14 +272,9 @@ async def completions(req, llm):
prompt = request.prompt
# TODO: Support multiple prompts
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
return error_response(
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
)
model_name, request_id = request.model, gen_random_uuid('cmpl')
created_time = int(time.monotonic())
- config = llm.config.with_request(request)
+ config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
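Note: compatible_options (the renamed with_request) presumably extracts only the generation options the request schema shares with the config, so the handlers can splat them straight into generate_iterator. Illustrative flow, with hypothetical option names and values:

config = llm.config.compatible_options(request)
# e.g. {'max_new_tokens': 256, 'temperature': 0.7, 'top_p': 0.9, 'n': 1} (hypothetical)
async for chunk in llm.generate_iterator(prompt, request_id=request_id, **config):
  ...  # stream each chunk back to the client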