From cbde63ab2427ad8c12d5e3dbe20a0c7e92d422e2 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Tue, 4 Jun 2024 21:21:12 +0000 Subject: [PATCH] feat: Use community chat template as the source of truth. Fall back to HF tokenizer template --- vllm-chat/bentovllm_openai/utils.py | 1 + vllm-chat/service.py | 27 +++++++++++++++++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/vllm-chat/bentovllm_openai/utils.py b/vllm-chat/bentovllm_openai/utils.py index 88f1dd61..38ab85bd 100644 --- a/vllm-chat/bentovllm_openai/utils.py +++ b/vllm-chat/bentovllm_openai/utils.py @@ -88,6 +88,7 @@ def openai_endpoints( if self.chat_template is None and chat_template_model_id is not None: from transformers import AutoTokenizer + # If no community chat template is provided, use the tokenizer's chat template _tokenizer = AutoTokenizer.from_pretrained(chat_template_model_id) self.chat_template = _tokenizer.chat_template diff --git a/vllm-chat/service.py b/vllm-chat/service.py index 959ae7dc..ed0dccb5 100644 --- a/vllm-chat/service.py +++ b/vllm-chat/service.py @@ -1,25 +1,30 @@ -import uuid -import json -import os -from typing import AsyncGenerator, Union -from typing_extensions import Annotated import functools +import json +import logging +import os +import uuid +from typing import AsyncGenerator, Union import bentoml -from annotated_types import Ge, Le -from bentovllm_openai.utils import openai_endpoints import yaml +from annotated_types import Ge, Le from bento_constants import CONSTANT_YAML - +from bentovllm_openai.utils import openai_endpoints +from typing_extensions import Annotated CONSTANTS = yaml.safe_load(CONSTANT_YAML) ENGINE_CONFIG = CONSTANTS["engine_config"] SERVICE_CONFIG = CONSTANTS["service_config"] +CHAT_TEMPLATE = CONSTANTS.get("chat_template") + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) @functools.lru_cache(maxsize=1) def _get_gen_config(community_chat_template: str) -> dict: + logger.info(f"Load community_chat_template: {community_chat_template}") chat_template_path = os.path.join( os.path.dirname(__file__), "chat_templates", "chat_templates" ) @@ -37,17 +42,19 @@ def _get_gen_config(community_chat_template: str) -> dict: @openai_endpoints( served_model_names=[ENGINE_CONFIG["model"]], + chat_template=_get_gen_config(CHAT_TEMPLATE)["template"] if CHAT_TEMPLATE else None, chat_template_model_id=ENGINE_CONFIG["model"], ) @bentoml.service(**SERVICE_CONFIG) class VLLM: def __init__(self) -> None: - from vllm import AsyncEngineArgs, AsyncLLMEngine from transformers import AutoTokenizer + from vllm import AsyncEngineArgs, AsyncLLMEngine ENGINE_ARGS = AsyncEngineArgs(**ENGINE_CONFIG) self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS) self.tokenizer = AutoTokenizer.from_pretrained(ENGINE_CONFIG["model"]) + logger.info(f"VLLM service initialized with model: {ENGINE_CONFIG['model']}") @bentoml.api async def generate( @@ -94,7 +101,7 @@ class VLLM: """ from vllm import SamplingParams - if CONSTANTS.get("chat_template"): # community chat template + if CHAT_TEMPLATE: # community chat template gen_config = _get_gen_config(CONSTANTS["chat_template"]) if not stop: if gen_config["stop_str"]: