Revert "chore: bump vllm to 0.4.3" (#3)

This commit is contained in:
bojiang
2024-06-12 20:03:58 +08:00
committed by GitHub
parent 41ad0a9b01
commit 52c625ec2f
3 changed files with 81 additions and 62 deletions

View File

@@ -14,37 +14,19 @@ T = t.TypeVar("T", bound=object)
if t.TYPE_CHECKING:
from vllm import AsyncLLMEngine
def openai_endpoints(
model_id: str,
response_role: str = "assistant",
served_model_names: t.Optional[list[str]] = None,
chat_template: t.Optional[str] = None,
chat_template_model_id: t.Optional[str] = None,
default_completion_parameters: t.Optional[t.Dict[str, t.Any]] = None,
default_chat_completion_parameters: t.Optional[t.Dict[str, t.Any]] = None,
):
if served_model_names is None:
served_model_names = [model_id]
def openai_endpoints(
served_model_names: list[str],
response_role: str = "assistant",
chat_template: t.Optional[str] = None,
chat_template_model_id: t.Optional[str] = None,
):
def openai_wrapper(svc: Service[T]):
cls = svc.inner
app = FastAPI()
# make sure default_*_parameters are in valid format
if default_completion_parameters is not None:
assert "prompt" not in default_completion_parameters
assert CompletionRequest(
prompt="", model="", **default_completion_parameters
)
if default_chat_completion_parameters is not None:
assert "messages" not in default_chat_completion_parameters
assert ChatCompletionRequest(
messages=[], model="", **default_chat_completion_parameters
)
class new_cls(cls):
def __init__(self):
@@ -56,32 +38,65 @@ def openai_endpoints(
# That's also why we put these codes inside class's
# `__init__` function
import bentoml
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_completion import (
OpenAIServingCompletion,
)
# we can do this because worker/engine_user_ray is always False for us
model_config = self.engine.engine.get_model_config()
# https://github.com/vllm-project/vllm/issues/2683
class PatchedOpenAIServingChat(OpenAIServingChat):
def __init__(
self,
engine: AsyncLLMEngine,
served_model_names: list[str],
response_role: str,
chat_template=None,
):
super(OpenAIServingChat, self).__init__(
engine=engine,
served_model_names=served_model_names,
lora_modules=None,
)
self.response_role = response_role
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
event_loop.create_task(
self._load_chat_template(chat_template)
)
else:
asyncio.run(self._load_chat_template(chat_template))
async def _load_chat_template(self, chat_template):
# Simply making this function async is usually already enough to give the parent
# class time to load the tokenizer (so usually no sleeping happens here)
# However, it feels safer to be explicit about this since asyncio does not
# guarantee the order in which scheduled tasks are run
while self.tokenizer is None:
await asyncio.sleep(0.1)
return super()._load_chat_template(chat_template)
self.openai_serving_completion = OpenAIServingCompletion(
engine=self.engine,
served_model_names=served_model_names,
model_config=model_config,
lora_modules=None,
)
self.chat_template = chat_template
if self.chat_template is None and chat_template_model_id is not None:
from transformers import AutoTokenizer
# If no community chat template is provided, use the tokenizer's chat template
_tokenizer = AutoTokenizer.from_pretrained(chat_template_model_id)
self.chat_template = _tokenizer.chat_template
self.openai_serving_chat = OpenAIServingChat(
self.openai_serving_chat = PatchedOpenAIServingChat(
engine=self.engine,
served_model_names=served_model_names,
response_role=response_role,
chat_template=self.chat_template,
model_config=model_config,
)
@app.get("/models")
@@ -91,38 +106,37 @@ def openai_endpoints(
@app.post("/chat/completions")
async def create_chat_completion(
request: ChatCompletionRequest,
raw_request: Request
request: ChatCompletionRequest, raw_request: Request
):
if default_chat_completion_parameters is not None:
for k, v in default_chat_completion_parameters.items():
if k not in request.__fields_set__:
setattr(request, k, v)
generator = await self.openai_serving_chat.create_chat_completion(
request, raw_request)
request, raw_request
)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
return JSONResponse(
content=generator.model_dump(), status_code=generator.code
)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
return StreamingResponse(
content=generator, media_type="text/event-stream"
)
else:
return JSONResponse(content=generator.model_dump())
@app.post("/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
if default_completion_parameters is not None:
for k, v in default_completion_parameters.items():
if k not in request.__fields_set__:
setattr(request, k, v)
async def create_completion(
request: CompletionRequest, raw_request: Request
):
generator = await self.openai_serving_completion.create_completion(
request, raw_request)
request, raw_request
)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
return JSONResponse(
content=generator.model_dump(), status_code=generator.code
)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
return StreamingResponse(
content=generator, media_type="text/event-stream"
)
else:
return JSONResponse(content=generator.model_dump())
@@ -137,8 +151,9 @@ def openai_endpoints(
# helper function to make a httpx client for BentoML service
def _make_httpx_client(url, svc):
import httpx
from urllib.parse import urlparse
import httpx
from bentoml._internal.utils.uri import uri_to_path
timeout = svc.config["traffic"]["timeout"]
@@ -154,9 +169,12 @@ def _make_httpx_client(url, svc):
elif parsed.scheme == "tcp":
target_url = f"http://{parsed.netloc}"
return httpx.Client(
transport=transport,
timeout=timeout,
follow_redirects=True,
headers=headers,
), target_url
return (
httpx.Client(
transport=transport,
timeout=timeout,
follow_redirects=True,
headers=headers,
),
target_url,
)

View File

@@ -1,3 +1,4 @@
vllm==0.4.3
torch==2.3.0
vllm==0.4.2
transformers==4.41.0
pyyaml

View File

@@ -41,7 +41,7 @@ def _get_gen_config(community_chat_template: str) -> dict:
@openai_endpoints(
model_id=ENGINE_CONFIG["model"],
served_model_names=[ENGINE_CONFIG["model"]],
chat_template=_get_gen_config(CHAT_TEMPLATE)["template"] if CHAT_TEMPLATE else None,
chat_template_model_id=ENGINE_CONFIG["model"],
)