feat: openai.Model.list() (#499)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
XunchaoZ
2023-10-14 16:33:49 -04:00
committed by GitHub
parent c1ca7ccd3b
commit d9183267dc
4 changed files with 33 additions and 6 deletions

View File

@@ -1,5 +1,6 @@
# mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
from __future__ import annotations
import logging
import typing as t
import warnings
@@ -23,6 +24,8 @@ warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast fro
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
logger = logging.getLogger(__name__)
model = svars.model
model_id = svars.model_id
adapter_map = svars.adapter_map
@@ -62,6 +65,8 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
output=bentoml.io.Text())
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
_model = input_dict.get('model', None)
if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
prompt = input_dict.pop('prompt', None)
if prompt is None: raise ValueError("'prompt' should not be None.")
stream = input_dict.pop('stream', False)
@@ -118,6 +123,8 @@ async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> s
}], model=runner.llm_type))),
output=bentoml.io.Text())
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
_model = input_dict.get('model', None)
if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
stream = input_dict.pop('stream', False)
config = {
@@ -158,12 +165,17 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
responses = await runner.generate.async_run(prompt, **config)
return orjson.dumps(
openllm.utils.bentoml_cattr.unstructure(
openllm.openai.ChatCompletionResponse(choices=[
openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
],
model=model) # TODO: logprobs, finish_reason and usage
openllm.openai.ChatCompletionResponse(
choices=[openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)],
model=runner.llm_type) # TODO: logprobs, finish_reason and usage
)).decode('utf-8')
def models_v1(_: Request) -> Response:
  """Handle the OpenAI-compatible ``GET /v1/models`` endpoint.

  Returns a JSON ``ModelList`` containing a single ``ModelCard`` for the
  one model this service runs (identified by ``runner.llm_type``).
  """
  card = openllm.openai.ModelCard(id=runner.llm_type)
  payload = openllm.utils.bentoml_cattr.unstructure(openllm.openai.ModelList(data=[card]))
  return JSONResponse(payload, status_code=200)
openai_app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
svc.mount_asgi_app(openai_app, path='/v1')
@svc.api(route='/v1/metadata',
input=bentoml.io.Text(),
output=bentoml.io.JSON.from_sample({

View File

@@ -119,6 +119,18 @@ class ChatCompletionResponseStream:
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
@attr.define
class ModelCard:
  # Metadata record for one served model, mirroring the OpenAI
  # ``/v1/models`` response schema (id/object/created/owned_by).
  id: str
  object: str = 'model'
  # Creation timestamp, captured lazily per instance via a factory
  # (``factory=`` is attrs' sugar for ``default=attr.Factory(...)``).
  created: int = attr.field(factory=lambda: int(time.time()))
  owned_by: str = 'na'
@attr.define
class ModelList:
  # Container for the OpenAI ``/v1/models`` listing: a ``list`` object
  # wrapping zero or more ``ModelCard`` entries.
  object: str = 'list'
  # Fresh empty list per instance (equivalent to ``factory=list``).
  data: t.List[ModelCard] = attr.field(default=attr.Factory(list))
def messages_to_prompt(messages: list[Message]) -> str:
  """Flatten an OpenAI-style chat transcript into a single prompt string.

  Each message is rendered as ``<role>: <content>`` on its own line, then a
  trailing ``assistant:`` line is appended to cue the model's next turn.
  An empty ``messages`` list yields ``'\\nassistant:'``.
  """
  body = '\n'.join(f"{msg['role']}: {msg['content']}" for msg in messages)
  return body + '\nassistant:'