diff --git a/changelog.d/499.feature.md b/changelog.d/499.feature.md
new file mode 100644
index 00000000..db9b1c39
--- /dev/null
+++ b/changelog.d/499.feature.md
@@ -0,0 +1 @@
+Add `/v1/models` endpoint for OpenAI compatible API
diff --git a/examples/openai_client.py b/examples/openai_client.py
index a8c61c66..4e3562a2 100644
--- a/examples/openai_client.py
+++ b/examples/openai_client.py
@@ -1,8 +1,10 @@
-import openai
+import openai, os
 
-openai.api_base = "http://localhost:3000/v1"
+openai.api_base = os.getenv('OPENLLM_ENDPOINT', "http://localhost:3000") + '/v1'
 openai.api_key = "na"
 
+print("Model:", openai.Model.list())
+
 response = openai.Completion.create(model="gpt-3.5-turbo-instruct", prompt="Write a tagline for an ice cream shop.", max_tokens=256)
 print(response)
 
diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py
index 43a64720..88792f18 100644
--- a/openllm-python/src/openllm/_service.py
+++ b/openllm-python/src/openllm/_service.py
@@ -1,5 +1,6 @@
 # mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
 from __future__ import annotations
+import logging
 import typing as t
 import warnings
 
@@ -23,6 +24,8 @@ warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast fro
 warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
 warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
 
+logger = logging.getLogger(__name__)
+
 model = svars.model
 model_id = svars.model_id
 adapter_map = svars.adapter_map
@@ -62,6 +65,8 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
          input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
          output=bentoml.io.Text())
 async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  _model = input_dict.get('model', None)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
   prompt = input_dict.pop('prompt', None)
   if prompt is None: raise ValueError("'prompt' should not be None.")
   stream = input_dict.pop('stream', False)
@@ -118,6 +123,8 @@ async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> s
          }], model=runner.llm_type))),
          output=bentoml.io.Text())
 async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  _model = input_dict.get('model', None)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
   prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
   stream = input_dict.pop('stream', False)
   config = {
@@ -158,12 +165,17 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
   responses = await runner.generate.async_run(prompt, **config)
   return orjson.dumps(
       openllm.utils.bentoml_cattr.unstructure(
-          openllm.openai.ChatCompletionResponse(choices=[
-              openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
-          ],
-                                                model=model) # TODO: logprobs, finish_reason and usage
+          openllm.openai.ChatCompletionResponse(
+              choices=[openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)],
+              model=runner.llm_type) # TODO: logprobs, finish_reason and usage
       )).decode('utf-8')
 
+def models_v1(_: Request) -> Response:
+  return JSONResponse(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ModelList(data=[openllm.openai.ModelCard(id=runner.llm_type)])), status_code=200)
+
+openai_app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
+svc.mount_asgi_app(openai_app, path='/v1')
+
 @svc.api(route='/v1/metadata',
          input=bentoml.io.Text(),
          output=bentoml.io.JSON.from_sample({
diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py
index 2157a6a9..7f3e3783 100644
--- a/openllm-python/src/openllm/protocol/openai.py
+++ b/openllm-python/src/openllm/protocol/openai.py
@@ -119,6 +119,18 @@ class ChatCompletionResponseStream:
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
 
+@attr.define
+class ModelCard:
+  id: str
+  object: str = 'model'
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  owned_by: str = 'na'
+
+@attr.define
+class ModelList:
+  object: str = 'list'
+  data: t.List[ModelCard] = attr.field(factory=list)
+
 def messages_to_prompt(messages: list[Message]) -> str:
   formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
   return f'{formatted}\nassistant:'
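
For a quick manual check of the new route, here is a minimal sketch (not part of the diff above) that queries /v1/models over plain HTTP with requests, assuming an OpenLLM server is already running at http://localhost:3000 (the default endpoint used in examples/openai_client.py). The response body should mirror the ModelList/ModelCard shape added in protocol/openai.py.

# Sketch: list the models served by a locally running OpenLLM server.
# Assumes the server is reachable at http://localhost:3000; override with OPENLLM_ENDPOINT.
import os
import requests

endpoint = os.getenv('OPENLLM_ENDPOINT', 'http://localhost:3000')
resp = requests.get(f'{endpoint}/v1/models', timeout=10)
resp.raise_for_status()
payload = resp.json()
# Expected shape: {"object": "list", "data": [{"id": "<llm_type>", "object": "model", "created": ..., "owned_by": "na"}]}
print([card['id'] for card in payload['data']])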
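
And, for reference, a self-contained sketch of what the new attrs protocol classes serialize to. The classes are re-declared here purely for illustration, and attr.asdict stands in for the openllm.utils.bentoml_cattr.unstructure call that models_v1 actually uses; the model id is a placeholder.

# Illustration only: standalone copies of the new protocol classes, serialized
# with attr.asdict to show the JSON body that models_v1 returns.
import time
import typing as t
import attr

@attr.define
class ModelCard:
  id: str
  object: str = 'model'
  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
  owned_by: str = 'na'

@attr.define
class ModelList:
  object: str = 'list'
  data: t.List[ModelCard] = attr.field(factory=list)

print(attr.asdict(ModelList(data=[ModelCard(id='gpt-3.5-turbo-instruct')])))
# -> {'object': 'list', 'data': [{'id': 'gpt-3.5-turbo-instruct', 'object': 'model', 'created': ..., 'owned_by': 'na'}]}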