From b0f53e2007faf14d6bfc5a3bd34e97fcf1baa3ae Mon Sep 17 00:00:00 2001 From: bojiang Date: Tue, 21 May 2024 03:11:49 +0800 Subject: [PATCH] opt --- vllm-chat/service.py | 62 +++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/vllm-chat/service.py b/vllm-chat/service.py index 4e7c15dd..8f575109 100644 --- a/vllm-chat/service.py +++ b/vllm-chat/service.py @@ -3,6 +3,7 @@ import json import os from typing import AsyncGenerator from typing_extensions import Annotated +import functools import bentoml from annotated_types import Ge, Le @@ -18,12 +19,22 @@ CHAT_TEMPLATE = CONSTANTS["chat_template"] SERVICE_CONFIG = CONSTANTS["service_config"] -CHAT_TEMPLATE_PATH = os.path.join( - os.path.dirname(__file__), "chat_templates", "chat_templates" -) -GEN_CONFIG_PATH = os.path.join( - os.path.dirname(__file__), "chat_templates", "generation_configs" -) +@functools.lru_cache(maxsize=1) +def _get_gen_config(): + import jinja2 + + chat_template_path = os.path.join( + os.path.dirname(__file__), "chat_templates", "chat_templates" + ) + config_path = os.path.join( + os.path.dirname(__file__), "chat_templates", "generation_configs" + ) + jinja_env = jinja2.Environment(loader=jinja2.FileSystemLoader(chat_template_path)) + with open(os.path.join(config_path, f"{CHAT_TEMPLATE}.json")) as f: + gen_config = json.load(f) + chat_template_file = gen_config["chat_template"].split("/")[-1] + gen_config["template"] = jinja_env.get_template(chat_template_file) + return gen_config @openai_endpoints( @@ -48,12 +59,14 @@ class VLLM: Le(ENGINE_CONFIG["max_model_len"]), ] = ENGINE_CONFIG["max_model_len"], stop: list[str] = [], + **kwargs, ) -> AsyncGenerator[str, None]: from vllm import SamplingParams SAMPLING_PARAM = SamplingParams( max_tokens=max_tokens, stop=stop, + **kwargs, ) stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM) @@ -66,41 +79,48 @@ class VLLM: @bentoml.api async def chat( self, - messages: list[dict[str, str]], + messages: list[dict[str, str]] = [ + {"role": "user", "content": "What is the meaning of life?"} + ], + model: str = "", max_tokens: Annotated[ int, Ge(128), Le(ENGINE_CONFIG["max_model_len"]), ] = ENGINE_CONFIG["max_model_len"], + stop: list[str] | str | None = None, + stop_token_ids: list[int] | None = None, ) -> AsyncGenerator[str, None]: """ light-weight chat API that takes in a list of messages and returns a response """ - from vllm import SamplingParams - import jinja2 - JINJA_ENV = jinja2.Environment( - loader=jinja2.FileSystemLoader(CHAT_TEMPLATE_PATH) - ) + gen_config = _get_gen_config() - with open(os.path.join(GEN_CONFIG_PATH, f"{CHAT_TEMPLATE}.json")) as f: - gen_config = json.load(f) - chat_template_file = gen_config["chat_template"].split("/")[-1] - template = JINJA_ENV.get_template(chat_template_file) + if stop_token_ids is None: + stop_token_ids = gen_config["stop_token_ids"] + if stop == "" or stop is None: + if gen_config["stop_str"] is None: + stop = [] + else: + stop = [gen_config["stop_str"]] SAMPLING_PARAM = SamplingParams( max_tokens=max_tokens, - stop_token_ids=gen_config["stop_token_ids"], - stop=( - [gen_config["stop_str"]] if gen_config["stop_str"] is not None else [] - ), + stop_token_ids=stop_token_ids, + stop=stop, ) if gen_config["system_prompt"] and messages[0].get("role") != "system": messages = [ dict(role="system", content=gen_config["system_prompt"]) ] + messages - prompt = template.render(messages=messages, add_generation_prompt=True) + + prompt = gen_config["template"].render( + messages=messages, + add_generation_prompt=True, + ) + stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM) cursor = 0