mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-22 16:07:24 -04:00
fix(build): check for parity (#508)
This commit is contained in:
@@ -50,15 +50,15 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
|
||||
|
||||
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
|
||||
def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
|
||||
def convert_data(data: SerializableData) -> openllm.GenerateOutput | t.Sequence[openllm.GenerateOutput]:
|
||||
try:
|
||||
data = orjson.loads(data)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise ValueError(f'Failed to decode JSON data: {data}') from err
|
||||
if openllm.utils.LazyType(DictStrAny).isinstance(data):
|
||||
return openllm.GenerationOutput(**data)
|
||||
return openllm.GenerateOutput(**data)
|
||||
elif openllm.utils.LazyType(ListAny).isinstance(data):
|
||||
return [openllm.GenerationOutput(**d) for d in data]
|
||||
return [openllm.GenerateOutput(**d) for d in data]
|
||||
else:
|
||||
raise NotImplementedError(f'Data {data} has unsupported type.')
|
||||
|
||||
@@ -73,7 +73,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
|
||||
return s == t
|
||||
|
||||
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
|
||||
def eq_output(s: openllm.GenerateOutput, t: openllm.GenerateOutput) -> bool:
|
||||
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
|
||||
|
||||
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
|
||||
|
||||
Reference in New Issue
Block a user