mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-21 15:39:36 -04:00
feat: openai.Model.list() (#499)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
import warnings
|
||||
|
||||
@@ -23,6 +24,8 @@ warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast fro
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model = svars.model
|
||||
model_id = svars.model_id
|
||||
adapter_map = svars.adapter_map
|
||||
@@ -62,6 +65,8 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
|
||||
input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
|
||||
output=bentoml.io.Text())
|
||||
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
|
||||
_model = input_dict.get('model', None)
|
||||
if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = input_dict.pop('prompt', None)
|
||||
if prompt is None: raise ValueError("'prompt' should not be None.")
|
||||
stream = input_dict.pop('stream', False)
|
||||
@@ -118,6 +123,8 @@ async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> s
|
||||
}], model=runner.llm_type))),
|
||||
output=bentoml.io.Text())
|
||||
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
|
||||
_model = input_dict.get('model', None)
|
||||
if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
|
||||
stream = input_dict.pop('stream', False)
|
||||
config = {
|
||||
@@ -158,12 +165,17 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
|
||||
responses = await runner.generate.async_run(prompt, **config)
|
||||
return orjson.dumps(
|
||||
openllm.utils.bentoml_cattr.unstructure(
|
||||
openllm.openai.ChatCompletionResponse(choices=[
|
||||
openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
|
||||
],
|
||||
model=model) # TODO: logprobs, finish_reason and usage
|
||||
openllm.openai.ChatCompletionResponse(
|
||||
choices=[openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)],
|
||||
model=runner.llm_type) # TODO: logprobs, finish_reason and usage
|
||||
)).decode('utf-8')
|
||||
|
||||
def models_v1(_: Request) -> Response:
|
||||
return JSONResponse(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ModelList(data=[openllm.openai.ModelCard(id=runner.llm_type)])), status_code=200)
|
||||
|
||||
openai_app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
|
||||
svc.mount_asgi_app(openai_app, path='/v1')
|
||||
|
||||
@svc.api(route='/v1/metadata',
|
||||
input=bentoml.io.Text(),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
|
||||
@@ -119,6 +119,18 @@ class ChatCompletionResponseStream:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
|
||||
|
||||
@attr.define
|
||||
class ModelCard:
|
||||
id: str
|
||||
object: str = 'model'
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
|
||||
owned_by: str = 'na'
|
||||
|
||||
@attr.define
|
||||
class ModelList:
|
||||
object: str = 'list'
|
||||
data: t.List[ModelCard] = attr.field(factory=list)
|
||||
|
||||
def messages_to_prompt(messages: list[Message]) -> str:
|
||||
formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
|
||||
return f'{formatted}\nassistant:'
|
||||
|
||||
Reference in New Issue
Block a user