mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-11 18:09:52 -04:00
opt
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing_extensions import Annotated
|
||||
import functools
|
||||
|
||||
import bentoml
|
||||
from annotated_types import Ge, Le
|
||||
@@ -18,12 +19,22 @@ CHAT_TEMPLATE = CONSTANTS["chat_template"]
|
||||
SERVICE_CONFIG = CONSTANTS["service_config"]
|
||||
|
||||
|
||||
# Both resource folders sit inside the "chat_templates" directory that lives
# next to this module; resolve that shared parent once and derive each path.
_CHAT_TEMPLATES_ROOT = os.path.join(os.path.dirname(__file__), "chat_templates")

# Directory holding the jinja chat template files.
CHAT_TEMPLATE_PATH = os.path.join(_CHAT_TEMPLATES_ROOT, "chat_templates")

# Directory holding the per-template generation config JSON files.
GEN_CONFIG_PATH = os.path.join(_CHAT_TEMPLATES_ROOT, "generation_configs")
|
||||
@functools.lru_cache(maxsize=1)
def _get_gen_config():
    """Load the generation config for ``CHAT_TEMPLATE`` with its chat template attached.

    Reads ``chat_templates/generation_configs/<CHAT_TEMPLATE>.json`` (relative
    to this module's directory), then loads the jinja2 template named by the
    config's ``"chat_template"`` entry from ``chat_templates/chat_templates``
    and stores it in the config under the ``"template"`` key.

    The result is memoized by ``functools.lru_cache``, so the files are read
    only on the first call and every caller receives the *same* dict object;
    callers should treat it as read-only.

    Returns:
        dict: the parsed JSON config plus a ``"template"`` entry holding the
        loaded jinja2 template.

    Raises:
        OSError: if the config JSON file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
        KeyError: if the config has no ``"chat_template"`` entry.
        jinja2.TemplateNotFound: if the referenced template file is missing.
    """
    # Imported lazily so jinja2 is only required once a config is requested.
    import jinja2

    resources_dir = os.path.join(os.path.dirname(__file__), "chat_templates")
    chat_template_path = os.path.join(resources_dir, "chat_templates")
    config_path = os.path.join(resources_dir, "generation_configs")

    jinja_env = jinja2.Environment(loader=jinja2.FileSystemLoader(chat_template_path))

    # JSON is UTF-8 by specification; pin the encoding instead of relying on
    # the platform's locale default (which is not UTF-8 everywhere).
    config_file = os.path.join(config_path, f"{CHAT_TEMPLATE}.json")
    with open(config_file, encoding="utf-8") as f:
        gen_config = json.load(f)

    # The config stores a "/"-separated path; only the final component is
    # looked up, since the loader is already rooted at the templates directory.
    chat_template_file = gen_config["chat_template"].split("/")[-1]
    gen_config["template"] = jinja_env.get_template(chat_template_file)
    return gen_config
|
||||
|
||||
|
||||
@openai_endpoints(
|
||||
@@ -48,12 +59,14 @@ class VLLM:
|
||||
Le(ENGINE_CONFIG["max_model_len"]),
|
||||
] = ENGINE_CONFIG["max_model_len"],
|
||||
stop: list[str] = [],
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
from vllm import SamplingParams
|
||||
|
||||
SAMPLING_PARAM = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
)
|
||||
stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM)
|
||||
|
||||
@@ -66,41 +79,48 @@ class VLLM:
|
||||
@bentoml.api
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[dict[str, str]] = [
|
||||
{"role": "user", "content": "What is the meaning of life?"}
|
||||
],
|
||||
model: str = "",
|
||||
max_tokens: Annotated[
|
||||
int,
|
||||
Ge(128),
|
||||
Le(ENGINE_CONFIG["max_model_len"]),
|
||||
] = ENGINE_CONFIG["max_model_len"],
|
||||
stop: list[str] | str | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
light-weight chat API that takes in a list of messages and returns a response
|
||||
"""
|
||||
|
||||
from vllm import SamplingParams
|
||||
import jinja2
|
||||
|
||||
JINJA_ENV = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(CHAT_TEMPLATE_PATH)
|
||||
)
|
||||
gen_config = _get_gen_config()
|
||||
|
||||
with open(os.path.join(GEN_CONFIG_PATH, f"{CHAT_TEMPLATE}.json")) as f:
|
||||
gen_config = json.load(f)
|
||||
chat_template_file = gen_config["chat_template"].split("/")[-1]
|
||||
template = JINJA_ENV.get_template(chat_template_file)
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = gen_config["stop_token_ids"]
|
||||
if stop == "" or stop is None:
|
||||
if gen_config["stop_str"] is None:
|
||||
stop = []
|
||||
else:
|
||||
stop = [gen_config["stop_str"]]
|
||||
|
||||
SAMPLING_PARAM = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=gen_config["stop_token_ids"],
|
||||
stop=(
|
||||
[gen_config["stop_str"]] if gen_config["stop_str"] is not None else []
|
||||
),
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop=stop,
|
||||
)
|
||||
if gen_config["system_prompt"] and messages[0].get("role") != "system":
|
||||
messages = [
|
||||
dict(role="system", content=gen_config["system_prompt"])
|
||||
] + messages
|
||||
prompt = template.render(messages=messages, add_generation_prompt=True)
|
||||
|
||||
prompt = gen_config["template"].render(
|
||||
messages=messages,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM)
|
||||
|
||||
cursor = 0
|
||||
|
||||
Reference in New Issue
Block a user