# mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
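# Service definition for an OpenLLM server: it wires a single LLM runner into a
# BentoML Service and exposes three HTTP endpoints, /v1/generate, /v1/generate_stream
# (server-sent events) and /v1/metadata.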
from __future__ import annotations
import logging
import os
import typing as t
import warnings

import _service_vars as svars
import orjson

import bentoml
import openllm
# The following warnings come from bitsandbytes and are probably not important for users to see
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')

logger = logging.getLogger(__name__)
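# Model parameters come from the companion _service_vars module, which the OpenLLM CLI
# populates (typically via environment variables) when the service is started or built.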
model = svars.model
model_id = svars.model_id
adapter_map = svars.adapter_map
model_tag = svars.model_tag
llm_config = openllm.AutoConfig.for_model(model)
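# Prompt template, system message, and serialisation format fall back to the model
# config defaults unless overridden via the OPENLLM_PROMPT_TEMPLATE,
# OPENLLM_SYSTEM_MESSAGE, and OPENLLM_SERIALIZATION environment variables.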
llm = openllm.LLM[t.Any, t.Any](
  model_id,
  llm_config=llm_config,
  model_tag=model_tag,
  prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), getattr(llm_config, 'default_prompt_template', None)),
  system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), getattr(llm_config, 'default_system_message', None)),
  serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']),
  adapter_map=orjson.loads(adapter_map),
)
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[llm.runner])
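# Input schema specialised to this model's configuration; its examples() provide the
# sample payloads shown in the generated API docs.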
llm_model_class = openllm.GenerationInput.from_llm_config(llm_config)
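# One-shot generation: the request payload is validated against the model-specific
# input schema and forwarded to the runner.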
@svc.api(route='/v1/generate', input=bentoml.io.JSON.from_sample(llm_model_class.examples().model_dump()), output=bentoml.io.JSON.from_sample(openllm.GenerationOutput.examples().model_dump()))
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
  return await llm.generate(**llm_model_class(**input_dict).model_dump())
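# Streaming generation over server-sent events: each partial result is emitted as a
# 'data: {...}' line and the stream is terminated with 'data: [DONE]'.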
@svc.api(route='/v1/generate_stream', input=bentoml.io.JSON.from_sample(llm_model_class.examples().model_dump()), output=bentoml.io.Text(content_type='text/event-stream'))
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
  async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()):
    yield f'data: {it.model_dump_json()}\n\n'
  yield 'data: [DONE]\n\n'
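# Exposes the resolved runtime configuration (model id, backend, timeout, prompt
# template, system message) so clients can introspect the running server.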
@svc.api(route='/v1/metadata', input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample(openllm.MetadataOutput.examples(llm).model_dump()))
def metadata_v1(_: str) -> openllm.MetadataOutput:
  return openllm.MetadataOutput(
    timeout=llm_config['timeout'],
    model_name=llm_config['model_name'],
    backend=llm.__llm_backend__,
    model_id=llm.model_id,
    configuration=llm_config.model_dump_json().decode(),
    prompt_template=llm.runner.prompt_template,
    system_message=llm.runner.system_message,
  )
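# A minimal usage sketch (assuming this file is saved as _service.py, OpenLLM and its
# dependencies are installed, and the default BentoML port 3000 is used; the exact CLI
# invocation and payload fields may differ between versions):
#
#   $ bentoml serve _service:svc
#   $ curl -X POST http://localhost:3000/v1/generate \
#       -H 'Content-Type: application/json' \
#       -d '{"prompt": "What is the meaning of life?"}'
#
# /v1/generate_stream returns 'data: {...}' server-sent events terminated by
# 'data: [DONE]', and /v1/metadata reports the resolved model configuration.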
openllm.mount_entrypoints(svc, llm)  # HACK: This must always be the last line in this file, as we do some monkey patching of the OpenAPI schema.