diff --git a/examples/openai_client.py b/examples/openai_client.py
new file mode 100644
index 00000000..86098807
--- /dev/null
+++ b/examples/openai_client.py
@@ -0,0 +1,39 @@
+import openai
+
+openai.api_base = "http://localhost:3000/v1"
+openai.api_key = "na"
+
+response = openai.Completion.create(model="gpt-3.5-turbo-instruct", prompt="Write a tagline for an ice cream shop.", max_tokens=256)
+
+print(response)
+
+for chunk in openai.Completion.create(
+    model="gpt-3.5-turbo-instruct",
+    prompt="Say this is a test",
+    max_tokens=7,
+    temperature=0,
+    stream=True
+):
+  print(chunk)
+
+
+completion = openai.ChatCompletion.create(
+    model="gpt-3.5-turbo",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello!"}
+    ]
+)
+
+print(completion)
+
+completion = openai.ChatCompletion.create(
+    model="gpt-3.5-turbo",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello!"}
+    ],
+    stream=True
+)
+
+for chunk in completion: print(chunk)
diff --git a/openllm-core/src/openllm_core/__init__.py b/openllm-core/src/openllm_core/__init__.py
index 36ed04ef..16c6b9ba 100644
--- a/openllm-core/src/openllm_core/__init__.py
+++ b/openllm-core/src/openllm_core/__init__.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from . import exceptions as exceptions
+from . import prompts as prompts
 from . import utils as utils
 from ._configuration import GenerationConfig as GenerationConfig
 from ._configuration import LLMConfig as LLMConfig
diff --git a/openllm-python/pyproject.toml b/openllm-python/pyproject.toml
index 47d78447..cfce5cdc 100644
--- a/openllm-python/pyproject.toml
+++ b/openllm-python/pyproject.toml
@@ -111,7 +111,7 @@ gptq = ["auto-gptq[triton]>=0.4.2", "optimum>=1.12.0"]
 grpc = ["openllm-client[grpc]"]
 llama = ["fairscale", "sentencepiece", "scipy"]
 mpt = ["triton", "einops"]
-openai = ["openai", "tiktoken"]
+openai = ["openai[embeddings]", "tiktoken"]
 opt = ["flax>=0.7", "jax", "jaxlib", "tensorflow", "keras"]
 playground = ["jupyter", "notebook", "ipython", "jupytext", "nbformat"]
 starcoder = ["bitsandbytes"]
diff --git a/openllm-python/src/openllm/__init__.py b/openllm-python/src/openllm/__init__.py
index 16d7e5e2..92b4c682 100644
--- a/openllm-python/src/openllm/__init__.py
+++ b/openllm-python/src/openllm/__init__.py
@@ -40,6 +40,7 @@ _import_structure: dict[str, list[str]] = {
   "playground": [],
   "testing": [],
   "prompts": ["PromptTemplate"],
+  "protocol": ["openai"],
   "utils": ["infer_auto_class"],
   "serialisation": ["ggml", "transformers"],
   "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
@@ -72,6 +73,7 @@ if _t.TYPE_CHECKING:
   from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
   from .serialisation import ggml as ggml, transformers as transformers
   from .prompts import PromptTemplate as PromptTemplate
+  from .protocol import openai as openai
   from .utils import infer_auto_class as infer_auto_class
 
 try:
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index bd606acd..8dcf8b1a 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -1241,6 +1241,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
       stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
       echo = attrs.pop('echo', False)
       stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
+      temperature = attrs.pop('temperature', self.config['temperature'])
+      top_p = attrs.pop('top_p', self.config['top_p'])
       adapter_name = attrs.pop('adapter_name', None)
       if adapter_name is not None: __self.set_adapter(adapter_name)
       request_id: str | None = attrs.pop('request_id', None)
@@ -1254,8 +1256,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
       for tid in stop_token_ids:
         if tid: stop_.add(self.tokenizer.decode(tid))
-      if self.config['temperature'] <= 1e-5: top_p = 1.0
-      else: top_p = self.config['top_p']
+      if temperature <= 1e-5: top_p = 1.0
       config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
       sampling_params = config.to_sampling_config()
@@ -1276,6 +1277,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
       echo = attrs.pop('echo', False)
       stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
       stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
+      temperature = attrs.pop('temperature', self.config['temperature'])
+      top_p = attrs.pop('top_p', self.config['top_p'])
       adapter_name = attrs.pop('adapter_name', None)
       if adapter_name is not None: __self.set_adapter(adapter_name)
       request_id: str | None = attrs.pop('request_id', None)
@@ -1289,8 +1292,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
       for tid in stop_token_ids:
         if tid: stop_.add(self.tokenizer.decode(tid))
-      if self.config['temperature'] <= 1e-5: top_p = 1.0
-      else: top_p = self.config['top_p']
+      if temperature <= 1e-5: top_p = 1.0
       config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
       sampling_params = config.to_sampling_config()
       async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py
index 1f538075..94d04fa9 100644
--- a/openllm-python/src/openllm/_service.py
+++ b/openllm-python/src/openllm/_service.py
@@ -69,6 +69,107 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
   else: return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())
 
+@svc.api(route='/v1/completions',
+         input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
+         output=bentoml.io.Text())
+async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  prompt = input_dict.pop('prompt', None)
+  if prompt is None: raise ValueError("'prompt' should not be None.")
+  stream = input_dict.pop('stream', False)
+  config = {
+      'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
+      'temperature': input_dict.pop('temperature', llm_config['temperature']),
+      'top_p': input_dict.pop('top_p', llm_config['top_p']),
+      'n': input_dict.pop('n', llm_config['n']),
+      'logprobs': input_dict.pop('logprobs', llm_config['logprobs']),
+      'echo': input_dict.pop('echo', False),
+      'stop': input_dict.pop('stop', llm_config['stop']),
+      'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
+      'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
+      'best_of': input_dict.pop('best_of', llm_config['best_of']),
+  }
+
+  async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
+    async for response in responses:
+      st = openllm.openai.CompletionResponseStream(choices=[openllm.openai.CompletionTextChoice(text=response, index=0)], model=runner.llm_type)  # TODO: logprobs, finish_reason
+      yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
+    yield 'data: [DONE]\n\n'
+
+  if stream:
+    ctx.response.headers['Content-Type'] = 'text/event-stream'
+    if runner.backend == 'vllm':
+      responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
+    else:
+      responses = runner.generate_iterator.async_stream(prompt, **config)
+    return stream_response_generator(responses)
+  else:
+    ctx.response.headers['Content-Type'] = 'application/json'
+    if runner.backend == 'vllm':
+      async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
+        responses = output
+      if responses is None: raise ValueError("'responses' should not be None.")
+    else:
+      responses = await runner.generate.async_run(prompt, **config)
+
+    return orjson.dumps(
+        openllm.utils.bentoml_cattr.unstructure(
+            openllm.openai.CompletionResponse(choices=[openllm.openai.CompletionTextChoice(text=response, index=i) for i, response in enumerate(responses)],
+                                              model=runner.llm_type)  # TODO: logprobs, finish_reason and usage
+        )).decode()
+
+@svc.api(route='/v1/chat/completions',
+         input=bentoml.io.JSON.from_sample(
+             openllm.utils.bentoml_cattr.unstructure(
+                 openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}],
+                                                      model=runner.llm_type))),
+         output=bentoml.io.Text())
+async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
+  stream = input_dict.pop('stream', False)
+  config = {
+      'temperature': input_dict.pop('temperature', llm_config['temperature']),
+      'top_p': input_dict.pop('top_p', llm_config['top_p']),
+      'n': input_dict.pop('n', llm_config['n']),
+      'echo': input_dict.pop('echo', False),
+      'stop': input_dict.pop('stop', llm_config['stop']),
+      'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
+      'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
+      'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
+  }
+
+  async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
+    async for response in responses:
+      st = openllm.openai.ChatCompletionResponseStream(
+          choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=response), finish_reason=None)], model=runner.llm_type)
+      yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
+    final = openllm.openai.ChatCompletionResponseStream(
+        choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=''), finish_reason='stop')], model=runner.llm_type)
+    yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(final)).decode()}\n\n'
+    yield 'data: [DONE]\n\n'
+
+  if stream:
+    ctx.response.headers['Content-Type'] = 'text/event-stream'
+    if runner.backend == 'vllm':
+      responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
+    else:
+      responses = runner.generate_iterator.async_stream(prompt, **config)
+    return stream_response_generator(responses)
+  else:
+    ctx.response.headers['Content-Type'] = 'application/json'
+    if runner.backend == 'vllm':
+      async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
+        responses = output
+      if responses is None: raise ValueError("'responses' should not be None.")
+    else:
+      responses = await runner.generate.async_run(prompt, **config)
+    return orjson.dumps(
+        openllm.utils.bentoml_cattr.unstructure(
+            openllm.openai.ChatCompletionResponse(choices=[
+                openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
+            ],
+                                                  model=runner.llm_type)  # TODO: logprobs, finish_reason and usage
+        )).decode('utf-8')
+
 @svc.api(route='/v1/metadata',
          input=bentoml.io.Text(),
          output=bentoml.io.JSON.from_sample({
diff --git a/openllm-python/src/openllm/models/opt/modeling_flax_opt.py b/openllm-python/src/openllm/models/opt/modeling_flax_opt.py
index f8757121..7b2bc981 100644
--- a/openllm-python/src/openllm/models/opt/modeling_flax_opt.py
+++ b/openllm-python/src/openllm/models/opt/modeling_flax_opt.py
@@ -4,7 +4,7 @@ import typing as t
 import bentoml
 import openllm
 
-from openllm._prompt import process_prompt
+from openllm_core.prompts import process_prompt
 from openllm.utils import generate_labels
 from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
 if t.TYPE_CHECKING: import transformers
diff --git a/openllm-python/src/openllm/models/opt/modeling_vllm_opt.py b/openllm-python/src/openllm/models/opt/modeling_vllm_opt.py
index fec0a85c..6c677ac3 100644
--- a/openllm-python/src/openllm/models/opt/modeling_vllm_opt.py
+++ b/openllm-python/src/openllm/models/opt/modeling_vllm_opt.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import typing as t
 
 import openllm
-from openllm_core._prompt import process_prompt
+from openllm_core.prompts import process_prompt
 from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
 if t.TYPE_CHECKING: import vllm, transformers
diff --git a/openllm-python/src/openllm/protocol/__init__.py b/openllm-python/src/openllm/protocol/__init__.py
new file mode 100644
index 00000000..bee13ece
--- /dev/null
+++ b/openllm-python/src/openllm/protocol/__init__.py
@@ -0,0 +1,19 @@
+'''Protocol-related packages for all library integrations.
+
+Currently supports the OpenAI-compatible API.
+'''
+from __future__ import annotations
+import os
+import typing as t
+
+from openllm_core.utils import LazyModule
+
+_import_structure: dict[str, list[str]] = {'openai': []}
+
+if t.TYPE_CHECKING:
+  from . import openai as openai
+
+__lazy = LazyModule(__name__, os.path.abspath(__file__), _import_structure)
+__all__ = __lazy.__all__
+__dir__ = __lazy.__dir__
+__getattr__ = __lazy.__getattr__
diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py
new file mode 100644
index 00000000..2157a6a9
--- /dev/null
+++ b/openllm-python/src/openllm/protocol/openai.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+import time
+import typing as t
+
+import attr
+
+import openllm_core
+
+@attr.define
+class CompletionRequest:
+  prompt: str
+  model: str = attr.field(default=None)
+  suffix: t.Optional[str] = attr.field(default=None)
+  max_tokens: t.Optional[int] = attr.field(default=16)
+  temperature: t.Optional[float] = attr.field(default=1.0)
+  top_p: t.Optional[float] = attr.field(default=1)
+  n: t.Optional[int] = attr.field(default=1)
+  stream: t.Optional[bool] = attr.field(default=False)
+  logprobs: t.Optional[int] = attr.field(default=None)
+  echo: t.Optional[bool] = attr.field(default=False)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  presence_penalty: t.Optional[float] = attr.field(default=0.0)
+  frequency_penalty: t.Optional[float] = attr.field(default=0.0)
+  best_of: t.Optional[int] = attr.field(default=1)
+  logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
+  user: t.Optional[str] = attr.field(default=None)
+
+@attr.define
+class ChatCompletionRequest:
+  messages: t.List[t.Dict[str, str]]
+  model: str = attr.field(default=None)
+  functions: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
+  function_calls: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
+  temperature: t.Optional[float] = attr.field(default=1.0)
+  top_p: t.Optional[float] = attr.field(default=1)
+  n: t.Optional[int] = attr.field(default=1)
+  stream: t.Optional[bool] = attr.field(default=False)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  max_tokens: t.Optional[int] = attr.field(default=None)
+  presence_penalty: t.Optional[float] = attr.field(default=0.0)
+  frequency_penalty: t.Optional[float] = attr.field(default=0.0)
+  logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
+  user: t.Optional[str] = attr.field(default=None)
+
+@attr.define
+class LogProbs:
+  text_offset: t.List[int] = attr.field(default=attr.Factory(list))
+  token_logprobs: t.List[float] = attr.field(default=attr.Factory(list))
+  tokens: t.List[str] = attr.field(default=attr.Factory(list))
+  top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))
+
+@attr.define
+class CompletionTextChoice:
+  text: str
+  index: int
+  logprobs: LogProbs = attr.field(default=attr.Factory(lambda: LogProbs()))
+  finish_reason: str = attr.field(default=None)
+
+@attr.define
+class Usage:
+  prompt_tokens: int = attr.field(default=0)
+  completion_tokens: int = attr.field(default=0)
+  total_tokens: int = attr.field(default=0)
+
+@attr.define
+class CompletionResponse:
+  choices: t.List[CompletionTextChoice]
+  model: str
+  object: str = 'text_completion'
+  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
+
+@attr.define
+class CompletionResponseStream:
+  choices: t.List[CompletionTextChoice]
+  model: str
+  object: str = 'text_completion'
+  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+
+LiteralRole = t.Literal['system', 'user', 'assistant']
+
+class Message(t.TypedDict):
+  role: LiteralRole
+  content: str
+
+@attr.define
+class Delta:
+  role: LiteralRole
+  content: str
+
+@attr.define
+class ChatCompletionChoice:
+  index: int
+  message: Message
+  finish_reason: str = attr.field(default=None)
+
+@attr.define
+class ChatCompletionStreamChoice:
+  index: int
+  delta: Message
+  finish_reason: str = attr.field(default=None)
+
+@attr.define
+class ChatCompletionResponse:
+  choices: t.List[ChatCompletionChoice]
+  model: str
+  object: str = 'chat.completion'
+  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
+
+@attr.define
+class ChatCompletionResponseStream:
+  choices: t.List[ChatCompletionStreamChoice]
+  model: str
+  object: str = 'chat.completion.chunk'
+  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+
+def messages_to_prompt(messages: list[Message]) -> str:
+  formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
+  return f'{formatted}\nassistant:'
diff --git a/tools/dependencies.py b/tools/dependencies.py
index 21a85ddc..37d6f873 100755
--- a/tools/dependencies.py
+++ b/tools/dependencies.py
@@ -128,7 +128,7 @@ FINE_TUNE_DEPS = ['peft>=0.5.0', 'bitsandbytes', 'datasets', 'accelerate', 'trl'
 FLAN_T5_DEPS = _ALL_RUNTIME_DEPS
 OPT_DEPS = _ALL_RUNTIME_DEPS
 GRPC_DEPS = ['openllm-client[grpc]']
-OPENAI_DEPS = ['openai', 'tiktoken']
+OPENAI_DEPS = ['openai[embeddings]', 'tiktoken']
 AGENTS_DEPS = ['transformers[agents]>=4.30', 'diffusers', 'soundfile']
 PLAYGROUND_DEPS = ['jupyter', 'notebook', 'ipython', 'jupytext', 'nbformat']
 GGML_DEPS = ['ctransformers']
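
As a quick end-to-end check of the new routes, the sketch below is one way to exercise them; it is not part of the diff above. It uses plain requests so the request bodies can be compared directly against the CompletionRequest/ChatCompletionRequest classes in openllm/protocol/openai.py, and it assumes a server is already running at http://localhost:3000, the same address used in examples/openai_client.py. The 'model' field is omitted because the service falls back to its own defaults.

# Sketch only: assumes an OpenLLM server at http://localhost:3000 serving the routes added above.
import json

import requests

ADDRESS = 'http://localhost:3000'

# Non-streaming completion; the JSON body mirrors protocol.openai.CompletionRequest.
completion = requests.post(f'{ADDRESS}/v1/completions', json={'prompt': 'What is 1+1?', 'max_tokens': 32, 'temperature': 0.2}).json()
print(completion['choices'][0]['text'])

# Non-streaming chat completion; the JSON body mirrors protocol.openai.ChatCompletionRequest.
chat = requests.post(f'{ADDRESS}/v1/chat/completions',
                     json={'messages': [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}]}).json()
print(chat['choices'][0]['message']['content'])

# Streaming completion: the service replies with server-sent events, one 'data: {...}' chunk
# per generated piece and a final 'data: [DONE]' marker (see stream_response_generator above).
with requests.post(f'{ADDRESS}/v1/completions', json={'prompt': 'Say this is a test', 'stream': True}, stream=True) as resp:
  for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '): continue
    payload = line[len('data: '):]
    if payload == '[DONE]': break
    print(json.loads(payload)['choices'][0]['text'], end='', flush=True)

Using requests rather than the openai SDK keeps the check independent of the SDK version and makes the raw payloads easy to compare against the attrs definitions.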