Merge pull request #1 from bentoml/rick-0604-pr-chat-template

feat: Use the community chat template as the source of truth. Fall back to the Hugging Face tokenizer's chat template when no community template is configured.
This commit is contained in:
Rick Zhou
2024-06-09 16:03:03 -07:00
committed by GitHub
2 changed files with 18 additions and 10 deletions

View File

@@ -88,6 +88,7 @@ def openai_endpoints(
if self.chat_template is None and chat_template_model_id is not None:
from transformers import AutoTokenizer
# If no community chat template is provided, use the tokenizer's chat template
_tokenizer = AutoTokenizer.from_pretrained(chat_template_model_id)
self.chat_template = _tokenizer.chat_template

View File

@@ -1,25 +1,30 @@
import uuid
import json
import os
from typing import AsyncGenerator, Union
from typing_extensions import Annotated
import functools
import json
import logging
import os
import uuid
from typing import AsyncGenerator, Union
import bentoml
from annotated_types import Ge, Le
from bentovllm_openai.utils import openai_endpoints
import yaml
from annotated_types import Ge, Le
from bento_constants import CONSTANT_YAML
from bentovllm_openai.utils import openai_endpoints
from typing_extensions import Annotated
# Deployment settings, parsed once at import time from the YAML text exposed
# by bento_constants.CONSTANT_YAML.
CONSTANTS = yaml.safe_load(CONSTANT_YAML)
# Keyword arguments unpacked into vllm's AsyncEngineArgs; its "model" entry is
# also used as the served model name and as the tokenizer / chat-template
# model id.
ENGINE_CONFIG = CONSTANTS["engine_config"]
# Keyword arguments unpacked into the @bentoml.service decorator.
SERVICE_CONFIG = CONSTANTS["service_config"]
# Optional community chat template name. It is None when the "chat_template"
# key is absent, in which case openai_endpoints receives chat_template=None.
CHAT_TEMPLATE = CONSTANTS.get("chat_template")
logger = logging.getLogger(__name__)
# This module logs the loaded chat template and the model name at INFO level.
logger.setLevel(logging.INFO)
@functools.lru_cache(maxsize=1)
def _get_gen_config(community_chat_template: str) -> dict:
logger.info(f"Load community_chat_template: {community_chat_template}")
chat_template_path = os.path.join(
os.path.dirname(__file__), "chat_templates", "chat_templates"
)
@@ -37,17 +42,19 @@ def _get_gen_config(community_chat_template: str) -> dict:
@openai_endpoints(
served_model_names=[ENGINE_CONFIG["model"]],
chat_template=_get_gen_config(CHAT_TEMPLATE)["template"] if CHAT_TEMPLATE else None,
chat_template_model_id=ENGINE_CONFIG["model"],
)
@bentoml.service(**SERVICE_CONFIG)
class VLLM:
def __init__(self) -> None:
from vllm import AsyncEngineArgs, AsyncLLMEngine
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine
ENGINE_ARGS = AsyncEngineArgs(**ENGINE_CONFIG)
self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS)
self.tokenizer = AutoTokenizer.from_pretrained(ENGINE_CONFIG["model"])
logger.info(f"VLLM service initialized with model: {ENGINE_CONFIG['model']}")
@bentoml.api
async def generate(
@@ -94,7 +101,7 @@ class VLLM:
"""
from vllm import SamplingParams
if CONSTANTS.get("chat_template"): # community chat template
if CHAT_TEMPLATE: # community chat template
gen_config = _get_gen_config(CONSTANTS["chat_template"])
if not stop:
if gen_config["stop_str"]: