This commit is contained in:
bojiang
2024-05-21 03:11:49 +08:00
parent 1860d9880e
commit b0f53e2007

View File

@@ -3,6 +3,7 @@ import json
import os
from typing import AsyncGenerator
from typing_extensions import Annotated
import functools
import bentoml
from annotated_types import Ge, Le
@@ -18,12 +19,22 @@ CHAT_TEMPLATE = CONSTANTS["chat_template"]
SERVICE_CONFIG = CONSTANTS["service_config"]
CHAT_TEMPLATE_PATH = os.path.join(
os.path.dirname(__file__), "chat_templates", "chat_templates"
)
GEN_CONFIG_PATH = os.path.join(
os.path.dirname(__file__), "chat_templates", "generation_configs"
)
@functools.lru_cache(maxsize=1)
def _get_gen_config():
import jinja2
chat_template_path = os.path.join(
os.path.dirname(__file__), "chat_templates", "chat_templates"
)
config_path = os.path.join(
os.path.dirname(__file__), "chat_templates", "generation_configs"
)
jinja_env = jinja2.Environment(loader=jinja2.FileSystemLoader(chat_template_path))
with open(os.path.join(config_path, f"{CHAT_TEMPLATE}.json")) as f:
gen_config = json.load(f)
chat_template_file = gen_config["chat_template"].split("/")[-1]
gen_config["template"] = jinja_env.get_template(chat_template_file)
return gen_config
@openai_endpoints(
@@ -48,12 +59,14 @@ class VLLM:
Le(ENGINE_CONFIG["max_model_len"]),
] = ENGINE_CONFIG["max_model_len"],
stop: list[str] = [],
**kwargs,
) -> AsyncGenerator[str, None]:
from vllm import SamplingParams
SAMPLING_PARAM = SamplingParams(
max_tokens=max_tokens,
stop=stop,
**kwargs,
)
stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM)
@@ -66,41 +79,48 @@ class VLLM:
@bentoml.api
async def chat(
self,
messages: list[dict[str, str]],
messages: list[dict[str, str]] = [
{"role": "user", "content": "What is the meaning of life?"}
],
model: str = "",
max_tokens: Annotated[
int,
Ge(128),
Le(ENGINE_CONFIG["max_model_len"]),
] = ENGINE_CONFIG["max_model_len"],
stop: list[str] | str | None = None,
stop_token_ids: list[int] | None = None,
) -> AsyncGenerator[str, None]:
"""
light-weight chat API that takes in a list of messages and returns a response
"""
from vllm import SamplingParams
import jinja2
JINJA_ENV = jinja2.Environment(
loader=jinja2.FileSystemLoader(CHAT_TEMPLATE_PATH)
)
gen_config = _get_gen_config()
with open(os.path.join(GEN_CONFIG_PATH, f"{CHAT_TEMPLATE}.json")) as f:
gen_config = json.load(f)
chat_template_file = gen_config["chat_template"].split("/")[-1]
template = JINJA_ENV.get_template(chat_template_file)
if stop_token_ids is None:
stop_token_ids = gen_config["stop_token_ids"]
if stop == "" or stop is None:
if gen_config["stop_str"] is None:
stop = []
else:
stop = [gen_config["stop_str"]]
SAMPLING_PARAM = SamplingParams(
max_tokens=max_tokens,
stop_token_ids=gen_config["stop_token_ids"],
stop=(
[gen_config["stop_str"]] if gen_config["stop_str"] is not None else []
),
stop_token_ids=stop_token_ids,
stop=stop,
)
if gen_config["system_prompt"] and messages[0].get("role") != "system":
messages = [
dict(role="system", content=gen_config["system_prompt"])
] + messages
prompt = template.render(messages=messages, add_generation_prompt=True)
prompt = gen_config["template"].render(
messages=messages,
add_generation_prompt=True,
)
stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM)
cursor = 0