fix(build): check for parity (#508)

This commit is contained in:
Aaron Pham
2023-10-16 17:33:47 -04:00
committed by GitHub
parent cb4b5acf63
commit d59a8860df
7 changed files with 41 additions and 18 deletions

View File

@@ -50,15 +50,15 @@ class ResponseComparator(JSONSnapshotExtension):
return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
def convert_data(data: SerializableData) -> openllm.GenerateOutput | t.Sequence[openllm.GenerateOutput]:
try:
data = orjson.loads(data)
except orjson.JSONDecodeError as err:
raise ValueError(f'Failed to decode JSON data: {data}') from err
if openllm.utils.LazyType(DictStrAny).isinstance(data):
return openllm.GenerationOutput(**data)
return openllm.GenerateOutput(**data)
elif openllm.utils.LazyType(ListAny).isinstance(data):
return [openllm.GenerationOutput(**d) for d in data]
return [openllm.GenerateOutput(**d) for d in data]
else:
raise NotImplementedError(f'Data {data} has unsupported type.')
@@ -73,7 +73,7 @@ class ResponseComparator(JSONSnapshotExtension):
def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
return s == t
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
def eq_output(s: openllm.GenerateOutput, t: openllm.GenerateOutput) -> bool:
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])