fix(build): check for parity (#508)

Authored by Aaron Pham on 2023-10-16 17:33:47 -04:00; committed by GitHub.
parent cb4b5acf63
commit d59a8860df
7 changed files with 41 additions and 18 deletions

View File

@@ -6,7 +6,8 @@ from . import utils as utils
 from ._configuration import GenerationConfig as GenerationConfig
 from ._configuration import LLMConfig as LLMConfig
 from ._configuration import SamplingParams as SamplingParams
-from ._schema import GenerationInput as GenerationInput
+from ._schema import GenerateInput as GenerateInput
+from ._schema import GenerateOutput as GenerateOutput
 from ._schema import GenerationOutput as GenerationOutput
 from ._schema import HfAgentInput as HfAgentInput
 from ._schema import MetadataOutput as MetadataOutput
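Note: the request/response pair for the generate endpoint is renamed from GenerationInput/GenerationOutput to GenerateInput/GenerateOutput, while the GenerationOutput name now resolves to a new chunk-based schema introduced in _schema.py below. A minimal sketch of the resulting import surface (assumes openllm_core at this commit):

# All three names are re-exported from the package root after this change.
from openllm_core import GenerateInput, GenerateOutput  # renamed request/response schema
from openllm_core import GenerationOutput               # new chunk-based output schema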

View File

@@ -10,12 +10,13 @@ from openllm_core._configuration import GenerationConfig
 from openllm_core._configuration import LLMConfig
 from .utils import bentoml_cattr
+from .utils import gen_random_uuid
 if t.TYPE_CHECKING:
   import vllm
 @attr.frozen(slots=True)
-class GenerationInput:
+class GenerateInput:
   prompt: str
   llm_config: LLMConfig
   adapter_name: str | None = attr.field(default=None)
@@ -31,13 +32,13 @@ class GenerationInput:
     return cls(**data)
   @classmethod
-  def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]:
+  def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerateInput]:
     import openllm
     return cls.from_llm_config(openllm.AutoConfig.for_model(model_name, **attrs))
   @classmethod
-  def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
-    return attr.make_class(inflection.camelize(llm_config['model_name']) + 'GenerationInput',
+  def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerateInput]:
+    return attr.make_class(inflection.camelize(llm_config['model_name']) + 'GenerateInput',
                            attrs={
                                'prompt': attr.field(type=str),
                                'llm_config': attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)),
@@ -45,7 +46,7 @@ class GenerationInput:
                            })
 @attr.frozen(slots=True)
-class GenerationOutput:
+class GenerateOutput:
   responses: t.List[t.Any]
   configuration: t.Dict[str, t.Any]
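Note: from_llm_config builds a model-specific subclass on the fly via attr.make_class, so the generated class name now also ends in GenerateInput. A hedged usage sketch (the 'opt' model name and the prompt are illustrative):

import openllm

# for_model('opt') resolves an OPTConfig and returns a dynamically built
# class named roughly 'OptGenerateInput' with the prompt/llm_config fields above.
OPTGenerateInput = openllm.GenerateInput.for_model('opt')
inputs = OPTGenerateInput(prompt='What does parity mean here?')
print(inputs.llm_config['model_name'])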
@@ -87,3 +88,22 @@ def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.An
 class HfAgentInput:
   inputs: str
   parameters: t.Dict[str, t.Any]
+FinishReason = t.Literal['length', 'stop']
+@attr.define
+class CompletionChunk:
+  index: int
+  text: str
+  token_ids: t.List[int]
+  cumulative_logprob: float
+  logprobs: t.Optional[t.List[t.Dict[int, float]]] = None
+  finish_reason: t.Optional[FinishReason] = None
+@attr.define
+class GenerationOutput:
+  prompt: str
+  finished: bool
+  outputs: t.List[CompletionChunk]
+  prompt_token_ids: t.Optional[t.List[int]] = attr.field(default=None)
+  request_id: str = attr.field(factory=lambda: gen_random_uuid())
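Note: the new GenerationOutput mirrors the shape of vLLM's RequestOutput (prompt, finished flag, a list of per-sequence CompletionChunk entries) and auto-fills request_id via the newly imported gen_random_uuid. A minimal construction sketch (field values are made up):

from openllm_core._schema import CompletionChunk, GenerationOutput

chunk = CompletionChunk(index=0, text='Hello!', token_ids=[15043, 29991], cumulative_logprob=-0.42)
out = GenerationOutput(prompt='Say hi', finished=True, outputs=[chunk])
print(out.request_id)  # generated by gen_random_uuid() when omitted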

View File

@@ -45,11 +45,11 @@ If a question does not make any sense, or is not factually coherent, explain why
 '''
 SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '<<SYS>>', '</s>', '<s>'
 # TODO: support history and v1 prompt implementation
-_v1_prompt, _v2_prompt = '''{instruction}''', '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} '''.format(start_key=SINST_KEY,
-                                                                                                                                           sys_key=SYS_KEY,
-                                                                                                                                           system_message='{system_message}',
-                                                                                                                                           instruction='{instruction}',
-                                                                                                                                           end_key=EINST_KEY)
+_v1_prompt, _v2_prompt = '''{instruction}''', '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n'''.format(start_key=SINST_KEY,
+                                                                                                                                            sys_key=SYS_KEY,
+                                                                                                                                            system_message='{system_message}',
+                                                                                                                                            instruction='{instruction}',
+                                                                                                                                            end_key=EINST_KEY)
 PROMPT_MAPPING = {'v1': _v1_prompt, 'v2': _v2_prompt}
 def _get_prompt(model_type: t.Literal['v1', 'v2']) -> PromptTemplate:
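Note: the only change here is the terminator of the v2 (Llama-2 chat) template: '{end_key} ' with a trailing space becomes '{end_key}\n', so rendered prompts now end with '[/INST]' followed by a newline. A self-contained rendering sketch (the system message and instruction are illustrative):

SINST_KEY, EINST_KEY, SYS_KEY = '[INST]', '[/INST]', '<<SYS>>'
_v2_prompt = '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n'''.format(
    start_key=SINST_KEY, sys_key=SYS_KEY, system_message='{system_message}', instruction='{instruction}', end_key=EINST_KEY)
print(_v2_prompt.format(system_message='You are a helpful assistant.', instruction='Hi there'), end='')
# [INST] <<SYS>>
# You are a helpful assistant.
# <<SYS>>
#
# Hi there
# [/INST]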

View File

@@ -15,7 +15,7 @@ from . import exceptions as exceptions, utils as utils
 from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
 from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
-from openllm_core._schema import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
+from openllm_core._schema import GenerateInput as GenerateInput, GenerateOutput as GenerateOutput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput
 from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig
 if openllm_core.utils.DEBUG:

View File

@@ -36,7 +36,7 @@ svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[r
 _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
 @svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
-async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
+async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerateOutput:
   echo = input_dict.pop('echo', False)
   qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
   config = qa_inputs.llm_config.model_dump()
@@ -46,7 +46,7 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
     if responses is None: raise ValueError("'responses' should not be None.")
   else:
     responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config)
-  return openllm.GenerationOutput(responses=responses, configuration=config)
+  return openllm.GenerateOutput(responses=responses, configuration=config)
 @svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream'))
 async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
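Note: the wire format of /v1/generate is unchanged; only the Python return annotation moves to the renamed GenerateOutput. A hedged client sketch (host, port, and prompt are assumptions, and llm_config may need model-specific keys):

import requests

payload = {'prompt': 'What is parity?', 'llm_config': {}, 'adapter_name': None}
res = requests.post('http://localhost:3000/v1/generate', json=payload)
print(res.json()['responses'])  # mirrors GenerateOutput.responses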

View File

@@ -522,6 +522,8 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
                        quantize=env['quantize_value'],
                        serialisation=_serialisation,
                        **attrs)
+  # FIX: This is a patch for _service_vars injection
+  if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
   labels = dict(llm.identifying_params)
   labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
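Note: the added guard only seeds OPENLLM_MODEL_ID when the caller has not already exported it, which is equivalent to os.environ.setdefault. A tiny sketch (the model id is illustrative):

import os

# Keeps any value the user already exported; otherwise falls back to the resolved model id.
os.environ.setdefault('OPENLLM_MODEL_ID', 'facebook/opt-125m')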

View File

@@ -50,15 +50,15 @@ class ResponseComparator(JSONSnapshotExtension):
     return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
   def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
-    def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
+    def convert_data(data: SerializableData) -> openllm.GenerateOutput | t.Sequence[openllm.GenerateOutput]:
       try:
         data = orjson.loads(data)
       except orjson.JSONDecodeError as err:
         raise ValueError(f'Failed to decode JSON data: {data}') from err
       if openllm.utils.LazyType(DictStrAny).isinstance(data):
-        return openllm.GenerationOutput(**data)
+        return openllm.GenerateOutput(**data)
       elif openllm.utils.LazyType(ListAny).isinstance(data):
-        return [openllm.GenerationOutput(**d) for d in data]
+        return [openllm.GenerateOutput(**d) for d in data]
       else:
         raise NotImplementedError(f'Data {data} has unsupported type.')
@@ -73,7 +73,7 @@ class ResponseComparator(JSONSnapshotExtension):
     def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
       return s == t
-    def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
+    def eq_output(s: openllm.GenerateOutput, t: openllm.GenerateOutput) -> bool:
       return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
     return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
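Note: convert_data re-hydrates snapshot JSON into the renamed schema before the element-wise comparison in eq_output runs. A small round-trip sketch (assumes openllm at this commit; the payload is made up):

import orjson
import openllm

raw = orjson.dumps({'responses': ['hi'], 'configuration': {}})
out = openllm.GenerateOutput(**orjson.loads(raw))
print(out.responses)  # ['hi']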