mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 02:20:32 -04:00
Merge pull request #1 from bentoml/rick-0604-pr-chat-template
feat: Use community chat template as the source of truth. Fall back to HF tokenizer template
This commit is contained in:
@@ -88,6 +88,7 @@ def openai_endpoints(
|
||||
if self.chat_template is None and chat_template_model_id is not None:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# If no community chat template is provided, use the tokenizer's chat template
|
||||
_tokenizer = AutoTokenizer.from_pretrained(chat_template_model_id)
|
||||
self.chat_template = _tokenizer.chat_template
|
||||
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator, Union
|
||||
from typing_extensions import Annotated
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
import bentoml
|
||||
from annotated_types import Ge, Le
|
||||
from bentovllm_openai.utils import openai_endpoints
|
||||
import yaml
|
||||
from annotated_types import Ge, Le
|
||||
from bento_constants import CONSTANT_YAML
|
||||
|
||||
from bentovllm_openai.utils import openai_endpoints
|
||||
from typing_extensions import Annotated
|
||||
|
||||
CONSTANTS = yaml.safe_load(CONSTANT_YAML)
|
||||
|
||||
ENGINE_CONFIG = CONSTANTS["engine_config"]
|
||||
SERVICE_CONFIG = CONSTANTS["service_config"]
|
||||
CHAT_TEMPLATE = CONSTANTS.get("chat_template")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _get_gen_config(community_chat_template: str) -> dict:
|
||||
logger.info(f"Load community_chat_template: {community_chat_template}")
|
||||
chat_template_path = os.path.join(
|
||||
os.path.dirname(__file__), "chat_templates", "chat_templates"
|
||||
)
|
||||
@@ -37,17 +42,19 @@ def _get_gen_config(community_chat_template: str) -> dict:
|
||||
|
||||
@openai_endpoints(
|
||||
served_model_names=[ENGINE_CONFIG["model"]],
|
||||
chat_template=_get_gen_config(CHAT_TEMPLATE)["template"] if CHAT_TEMPLATE else None,
|
||||
chat_template_model_id=ENGINE_CONFIG["model"],
|
||||
)
|
||||
@bentoml.service(**SERVICE_CONFIG)
|
||||
class VLLM:
|
||||
def __init__(self) -> None:
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine
|
||||
|
||||
ENGINE_ARGS = AsyncEngineArgs(**ENGINE_CONFIG)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(ENGINE_CONFIG["model"])
|
||||
logger.info(f"VLLM service initialized with model: {ENGINE_CONFIG['model']}")
|
||||
|
||||
@bentoml.api
|
||||
async def generate(
|
||||
@@ -94,7 +101,7 @@ class VLLM:
|
||||
"""
|
||||
from vllm import SamplingParams
|
||||
|
||||
if CONSTANTS.get("chat_template"): # community chat template
|
||||
if CHAT_TEMPLATE: # community chat template
|
||||
gen_config = _get_gen_config(CONSTANTS["chat_template"])
|
||||
if not stop:
|
||||
if gen_config["stop_str"]:
|
||||
|
||||
Reference in New Issue
Block a user