fix: make sure to respect additional parameters parse (#981)

Author: Aaron Pham
Date: 2024-05-08 13:53:56 -04:00 (committed by GitHub)
Parent: 46433bc745
Commit: 42417dbdbf
4 changed files with 28 additions and 32 deletions
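
This change threads additional generation parameters all the way through: the service endpoints forward the whole llm_config dict as keyword arguments, LLM.generate / generate_iterator rebuild their config with model_construct_env(**attrs), and vLLM's SamplingParams is built from that merged config, so parameters passed via llm_config actually reach the engine. A minimal sketch of the new flow (illustrative only; it assumes a config object with the model_construct_env and __getitem__ behaviour shown in the diff below, not the exact OpenLLM API):

import inspect
from vllm import SamplingParams

def sampling_params_from(config, **attrs):
  # Hypothetical helper mirroring LLM.generate_iterator after this change:
  # fold caller-supplied generation parameters into the config first ...
  merged = config.model_construct_env(**{k: v for k, v in attrs.items() if v is not None})
  # ... then build SamplingParams from every field its signature accepts.
  keys = set(inspect.signature(SamplingParams).parameters.keys())
  return SamplingParams(**{k: merged[k] for k in keys})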


@@ -289,7 +289,7 @@ def construct_python_options(llm_config, llm_fs):
   # TODO: Add this line back once 0.5 is out, for now depends on OPENLLM_DEV_BUILD
   # packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'openllm[vllm]>0.4', 'vllm>=0.3']
-  packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'vllm>=0.3']
+  packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'vllm>=0.3', 'flash-attn']
   if llm_config['requirements'] is not None:
     packages.extend(llm_config['requirements'])
   built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
@@ -462,7 +462,7 @@ def build_command(
     labels=labels,
     models=models,
     envs=[
-      EnvironmentEntry(name='OPENLLM_CONFIG', value=llm_config.model_dump_json()),
+      EnvironmentEntry(name='OPENLLM_CONFIG', value=f"'{llm_config.model_dump_json()}'"),
       EnvironmentEntry(name='NVIDIA_DRIVER_CAPABILITIES', value='compute,utility'),
     ],
     description=f"OpenLLM service for {llm_config['start_name']}",
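
A note on the OPENLLM_CONFIG change above: wrapping the serialized config in literal single quotes presumably keeps the JSON payload intact when the environment entry is rendered through a shell-like layer; whatever reads the variable back would then strip the quotes before parsing. An illustrative reader, not OpenLLM's actual loader:

import json
import os

raw = os.environ.get('OPENLLM_CONFIG', '{}')
# The value may arrive wrapped in literal single quotes (see the diff above),
# so strip them before handing it to json.loads.
llm_config = json.loads(raw.strip("'"))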


@@ -10,7 +10,7 @@ from openllm_core.utils import (
   dict_filter_none,
 )
 from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
-from openllm_core._schemas import GenerationOutput, GenerationInput
+from openllm_core._schemas import GenerationOutput
 Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
@@ -149,7 +149,7 @@ class LLM:
   ) -> t.AsyncGenerator[RequestOutput, None]:
     from vllm import SamplingParams
-    config = self.config.generation_config.model_copy(update=dict_filter_none(attrs))
+    config = self.config.model_construct_env(**dict_filter_none(attrs))
     stop_token_ids = stop_token_ids or []
     eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
@@ -172,12 +172,14 @@ class LLM:
     top_p = 1.0 if config['temperature'] <= 1e-5 else config['top_p']
     config = config.model_copy(update=dict(stop=list(stop), stop_token_ids=stop_token_ids, top_p=top_p))
-    params = {k: getattr(config, k, None) for k in set(inspect.signature(SamplingParams).parameters.keys())}
-    sampling_params = SamplingParams(**{k: v for k, v in params.items() if v is not None})
     try:
       async for it in self._model.generate(
-        prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids
+        prompt,
+        sampling_params=SamplingParams(**{
+          k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())
+        }),
+        request_id=request_id,
+        prompt_token_ids=prompt_token_ids,
       ):
         yield it
     except Exception as err:
@@ -191,15 +193,13 @@ class LLM:
     stop_token_ids: list[int] | None = None,
     request_id: str | None = None,
     adapter_name: str | None = None,
-    *,
-    _generated: GenerationInput | None = None,
     **attrs: t.Any,
   ) -> GenerationOutput:
     if stop is not None:
       attrs.update({'stop': stop})
     if stop_token_ids is not None:
       attrs.update({'stop_token_ids': stop_token_ids})
-    config = self.config.model_copy(update=attrs)
+    config = self.config.model_construct_env(**attrs)
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
     async for result in self.generate_iterator(
       prompt,
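
With generate and generate_iterator now parsing arbitrary keyword arguments through model_construct_env, the LLMService endpoints below can forward the whole llm_config dict instead of cherry-picking stop and stop_token_ids. A hypothetical caller showing the effect (the llm instance and the chosen fields are assumptions for illustration, not part of this diff):

async def ask(llm):
  # `llm` is an openllm LLM instance constructed elsewhere (assumed).
  # Any generation-config field passed as a keyword argument (temperature,
  # top_k, max_new_tokens, ...) is folded in via model_construct_env and
  # reaches vLLM's SamplingParams.
  out = await llm.generate(prompt='What is the meaning of life?', temperature=0.2, top_k=40)
  print(out.model_dump_json())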


@@ -58,39 +58,31 @@ class LLMService:
   @core.utils.api(route='/v1/generate')
   async def generate_v1(
     self,
-    llm_config: t.Dict[str, t.Any] = pydantic.Field(default_factory=lambda: llm_config, description='LLM Config'),
     prompt: str = pydantic.Field(default='What is the meaning of life?', description='Given prompt to generate from'),
     prompt_token_ids: t.Optional[t.List[int]] = None,
     stop: t.Optional[t.List[str]] = None,
     stop_token_ids: t.Optional[t.List[int]] = None,
     request_id: t.Optional[str] = None,
+    llm_config: t.Dict[str, t.Any] = pydantic.Field(default=llm_config, description='LLM Config'),
   ) -> core.GenerationOutput:
+    llm_config.update(stop=stop, stop_token_ids=stop_token_ids)
     return await self.llm.generate(
-      prompt=prompt,
-      prompt_token_ids=prompt_token_ids,
-      llm_config=llm_config,
-      stop=stop,
-      stop_token_ids=stop_token_ids,
-      request_id=request_id,
+      prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config
     )
   @core.utils.api(route='/v1/generate_stream')
   async def generate_stream_v1(
     self,
-    llm_config: t.Dict[str, t.Any] = pydantic.Field(default_factory=lambda: llm_config, description='LLM Config'),
     prompt: str = pydantic.Field(default='What is the meaning of life?', description='Given prompt to generate from'),
     prompt_token_ids: t.Optional[t.List[int]] = None,
     stop: t.Optional[t.List[str]] = None,
     stop_token_ids: t.Optional[t.List[int]] = None,
     request_id: t.Optional[str] = None,
+    llm_config: t.Dict[str, t.Any] = pydantic.Field(default=llm_config, description='LLM Config'),
   ) -> t.AsyncGenerator[str, None]:
+    llm_config.update(stop=stop, stop_token_ids=stop_token_ids)
     async for generated in self.llm.generate_iterator(
-      prompt=prompt,
-      prompt_token_ids=prompt_token_ids,
-      llm_config=llm_config,
-      stop=stop,
-      stop_token_ids=stop_token_ids,
-      request_id=request_id,
+      prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config
     ):
       yield f'data: {core.GenerationOutput.from_vllm(generated).model_dump_json()}\n\n'
     yield 'data: [DONE]\n\n'
@@ -108,13 +100,15 @@ class LLMService:
   @core.utils.api(route='/v1/helpers/messages')
   def helpers_messages_v1(
     self,
-    message: Annotated[t.Dict[str, t.Any], MessagesConverterInput] = MessagesConverterInput(
-      add_generation_prompt=False,
-      messages=[
-        MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
-        MessageParam(role='user', content='Hi there!'),
-        MessageParam(role='assistant', content='Yes?'),
-      ],
+    message: Annotated[t.Dict[str, t.Any], MessagesConverterInput] = pydantic.Field(
+      default=MessagesConverterInput(
+        add_generation_prompt=False,
+        messages=[
+          MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
+          MessageParam(role='user', content='Hi there!'),
+          MessageParam(role='assistant', content='Yes?'),
+        ],
+      )
     ),
   ) -> str:
     return self.llm._tokenizer.apply_chat_template(
@@ -136,6 +130,7 @@ class LLMService:
         MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
         MessageParam(role='user', content='Hi there!'),
         MessageParam(role='assistant', content='Yes?'),
+        MessageParam(role='user', content='What is the meaning of life?'),
       ],
       model=core.utils.normalise_model_name(model_id),
       n=1,