mirror of https://github.com/bentoml/OpenLLM.git
fix: make sure to respect additional parameters parse (#981)
@@ -289,7 +289,7 @@ def construct_python_options(llm_config, llm_fs):
   # TODO: Add this line back once 0.5 is out, for now depends on OPENLLM_DEV_BUILD
   # packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'openllm[vllm]>0.4', 'vllm>=0.3']
-  packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'vllm>=0.3']
+  packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'vllm>=0.3', 'flash-attn']
   if llm_config['requirements'] is not None:
     packages.extend(llm_config['requirements'])
   built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
@@ -462,7 +462,7 @@ def build_command(
     labels=labels,
     models=models,
     envs=[
-      EnvironmentEntry(name='OPENLLM_CONFIG', value=llm_config.model_dump_json()),
+      EnvironmentEntry(name='OPENLLM_CONFIG', value=f"'{llm_config.model_dump_json()}'"),
       EnvironmentEntry(name='NVIDIA_DRIVER_CAPABILITIES', value='compute,utility'),
     ],
     description=f"OpenLLM service for {llm_config['start_name']}",

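The OPENLLM_CONFIG change above wraps the JSON dump in literal single quotes. A minimal round-trip sketch under that assumption (the consumer-side stripping shown here is illustrative, not OpenLLM's actual reader):

import json, os

# Producer side: mimic the new EnvironmentEntry value, a JSON document
# wrapped in literal single quotes.
config = {'max_new_tokens': 128, 'temperature': 0.2}
os.environ['OPENLLM_CONFIG'] = f"'{json.dumps(config)}'"

# Consumer side (assumption): strip the surrounding quotes before parsing.
raw = os.environ['OPENLLM_CONFIG'].strip().strip("'")
assert json.loads(raw) == config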
@@ -10,7 +10,7 @@ from openllm_core.utils import (
   dict_filter_none,
 )
 from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
-from openllm_core._schemas import GenerationOutput, GenerationInput
+from openllm_core._schemas import GenerationOutput
 
 Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
 
@@ -149,7 +149,7 @@ class LLM:
   ) -> t.AsyncGenerator[RequestOutput, None]:
     from vllm import SamplingParams
 
-    config = self.config.generation_config.model_copy(update=dict_filter_none(attrs))
+    config = self.config.model_construct_env(**dict_filter_none(attrs))
 
     stop_token_ids = stop_token_ids or []
     eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
@@ -172,12 +172,14 @@ class LLM:
     top_p = 1.0 if config['temperature'] <= 1e-5 else config['top_p']
     config = config.model_copy(update=dict(stop=list(stop), stop_token_ids=stop_token_ids, top_p=top_p))
 
-    params = {k: getattr(config, k, None) for k in set(inspect.signature(SamplingParams).parameters.keys())}
-    sampling_params = SamplingParams(**{k: v for k, v in params.items() if v is not None})
-
     try:
       async for it in self._model.generate(
-        prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids
+        prompt,
+        sampling_params=SamplingParams(**{
+          k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())
+        }),
+        request_id=request_id,
+        prompt_token_ids=prompt_token_ids,
       ):
         yield it
     except Exception as err:
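Both the old and new code build SamplingParams by intersecting the generation config with the constructor's signature; the new version indexes the config directly instead of going through getattr with a None fallback. A self-contained sketch of that signature-filtering pattern, using a plain dataclass as a stand-in for vllm's SamplingParams (and filtering dict keys against the signature so the sketch needs nothing beyond the standard library):

import inspect
from dataclasses import dataclass

@dataclass
class FakeSamplingParams:  # stand-in for vllm.SamplingParams
  temperature: float = 1.0
  top_p: float = 1.0
  max_tokens: int = 16

config = {'temperature': 0.2, 'top_p': 0.9, 'max_tokens': 64, 'unrelated_key': 'ignored'}

# Keep only the keys the constructor actually accepts, then construct it.
accepted = set(inspect.signature(FakeSamplingParams).parameters)
params = FakeSamplingParams(**{k: v for k, v in config.items() if k in accepted})
print(params)  # FakeSamplingParams(temperature=0.2, top_p=0.9, max_tokens=64)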
@@ -191,15 +193,13 @@ class LLM:
     stop_token_ids: list[int] | None = None,
     request_id: str | None = None,
     adapter_name: str | None = None,
-    *,
-    _generated: GenerationInput | None = None,
     **attrs: t.Any,
   ) -> GenerationOutput:
     if stop is not None:
       attrs.update({'stop': stop})
     if stop_token_ids is not None:
       attrs.update({'stop_token_ids': stop_token_ids})
-    config = self.config.model_copy(update=attrs)
+    config = self.config.model_construct_env(**attrs)
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
     async for result in self.generate_iterator(
       prompt,
@@ -58,39 +58,31 @@ class LLMService:
   @core.utils.api(route='/v1/generate')
   async def generate_v1(
     self,
-    llm_config: t.Dict[str, t.Any] = pydantic.Field(default_factory=lambda: llm_config, description='LLM Config'),
     prompt: str = pydantic.Field(default='What is the meaning of life?', description='Given prompt to generate from'),
     prompt_token_ids: t.Optional[t.List[int]] = None,
     stop: t.Optional[t.List[str]] = None,
     stop_token_ids: t.Optional[t.List[int]] = None,
     request_id: t.Optional[str] = None,
+    llm_config: t.Dict[str, t.Any] = pydantic.Field(default=llm_config, description='LLM Config'),
   ) -> core.GenerationOutput:
+    llm_config.update(stop=stop, stop_token_ids=stop_token_ids)
     return await self.llm.generate(
-      prompt=prompt,
-      prompt_token_ids=prompt_token_ids,
-      llm_config=llm_config,
-      stop=stop,
-      stop_token_ids=stop_token_ids,
-      request_id=request_id,
+      prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config
     )
 
   @core.utils.api(route='/v1/generate_stream')
   async def generate_stream_v1(
     self,
-    llm_config: t.Dict[str, t.Any] = pydantic.Field(default_factory=lambda: llm_config, description='LLM Config'),
     prompt: str = pydantic.Field(default='What is the meaning of life?', description='Given prompt to generate from'),
     prompt_token_ids: t.Optional[t.List[int]] = None,
     stop: t.Optional[t.List[str]] = None,
     stop_token_ids: t.Optional[t.List[int]] = None,
     request_id: t.Optional[str] = None,
+    llm_config: t.Dict[str, t.Any] = pydantic.Field(default=llm_config, description='LLM Config'),
   ) -> t.AsyncGenerator[str, None]:
+    llm_config.update(stop=stop, stop_token_ids=stop_token_ids)
     async for generated in self.llm.generate_iterator(
-      prompt=prompt,
-      prompt_token_ids=prompt_token_ids,
-      llm_config=llm_config,
-      stop=stop,
-      stop_token_ids=stop_token_ids,
-      request_id=request_id,
+      prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config
     ):
       yield f'data: {core.GenerationOutput.from_vllm(generated).model_dump_json()}\n\n'
     yield 'data: [DONE]\n\n'
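Both endpoints now fold stop and stop_token_ids into llm_config and splat the whole dict into the call, so any extra generation parameters sent in the request body reach generate/generate_iterator instead of being dropped. A toy sketch of that forwarding (generate here is a hypothetical stand-in, not the real LLM.generate):

import asyncio
import typing as t

async def generate(prompt: str, request_id: t.Optional[str] = None, **attrs: t.Any) -> dict:
  # Anything not named explicitly lands in attrs, e.g. temperature or top_p.
  return {'prompt': prompt, 'overrides': attrs}

async def handler() -> dict:
  llm_config = {'temperature': 0.2, 'top_p': 0.9}     # arrived in the request body
  llm_config.update(stop=['\n'], stop_token_ids=[2])  # merged like the endpoints above
  return await generate(prompt='hello', request_id='req-1', **llm_config)

print(asyncio.run(handler()))
# {'prompt': 'hello', 'overrides': {'temperature': 0.2, 'top_p': 0.9, 'stop': ['\n'], 'stop_token_ids': [2]}}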
@@ -108,13 +100,15 @@ class LLMService:
   @core.utils.api(route='/v1/helpers/messages')
   def helpers_messages_v1(
     self,
-    message: Annotated[t.Dict[str, t.Any], MessagesConverterInput] = MessagesConverterInput(
-      add_generation_prompt=False,
-      messages=[
-        MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
-        MessageParam(role='user', content='Hi there!'),
-        MessageParam(role='assistant', content='Yes?'),
-      ],
+    message: Annotated[t.Dict[str, t.Any], MessagesConverterInput] = pydantic.Field(
+      default=MessagesConverterInput(
+        add_generation_prompt=False,
+        messages=[
+          MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
+          MessageParam(role='user', content='Hi there!'),
+          MessageParam(role='assistant', content='Yes?'),
+        ],
+      )
     ),
   ) -> str:
     return self.llm._tokenizer.apply_chat_template(
@@ -136,6 +130,7 @@ class LLMService:
         MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
         MessageParam(role='user', content='Hi there!'),
         MessageParam(role='assistant', content='Yes?'),
+        MessageParam(role='user', content='What is the meaning of life?'),
       ],
       model=core.utils.normalise_model_name(model_id),
       n=1,
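The helpers endpoint renders message lists like the one above through the tokenizer's chat template. A hedged usage sketch with Hugging Face transformers (the checkpoint name is only an example; any model that ships a chat template works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
messages = [
  {'role': 'system', 'content': 'You are acting as Ernest Hemmingway.'},
  {'role': 'user', 'content': 'What is the meaning of life?'},
]
# tokenize=False returns the rendered prompt string rather than token ids.
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)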