mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-23 16:39:48 -04:00
tests: add additional basic testing (#982)
* chore: update rebase tests Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update partial clients before removing Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * fix: update clients parsing logics to work with 0.5 Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: ignore ci runs as to run locally Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update async client tests Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update pre-commit Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
78
openllm-python/tests/_data.py
Normal file
78
openllm-python/tests/_data.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from openllm_core._typing_compat import TypedDict
|
||||
from datasets import load_dataset
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
# Fixed completion length (in tokens) used by `prepare_sharegpt_request` when a
# caller opts out of the real completion lengths via `max_output_length`.
FIXED_OUTPUT_LENGTH = 128
class DatasetEntry(TypedDict):
  """One ShareGPT conversation pair: the opening human turn and the model reply.

  NOTE(review): `transform` in `prepare_sharegpt_request` builds this dict in
  batched mode (each key maps to a *list* of strings); the `str` annotations
  here describe a single, un-batched entry.
  """

  # Text of the first ('human') turn of the conversation.
  human: str
  # Text of the second ('gpt') turn of the conversation.
  gpt: str
class SampledRequest(TypedDict):
  """A single benchmark request sampled from ShareGPT by `prepare_sharegpt_request`."""

  # The human prompt text sent to the model.
  prompt: str
  # Prompt length in tokens, as measured by the supplied tokenizer.
  prompt_length: int
  # Expected completion length in tokens (real completion length, or a fixed
  # value when the caller supplies `max_output_length`).
  output_length: int
def prepare_sharegpt_request(
  num_requests: int, tokenizer: PreTrainedTokenizerBase, max_output_length: int | None = None
) -> list[SampledRequest]:
  """Sample ``num_requests`` prompts from the ShareGPT dataset for benchmarking.

  Downloads the unfiltered ShareGPT split, keeps conversations with at least
  one human turn and one model reply, tokenizes both sides with ``tokenizer``,
  drops degenerate or oversized examples, shuffles with a fixed seed (42) for
  reproducibility, and returns the first ``num_requests`` entries.

  Args:
    num_requests: Number of sampled requests to return.
    tokenizer: Tokenizer used to measure prompt/completion token lengths.
    max_output_length: When given, every request reports this value as its
      ``output_length``; when ``None``, the real completion length is used.

  Returns:
    A list of ``SampledRequest`` dicts (``prompt``, ``prompt_length``,
    ``output_length``).
  """

  def transform(examples) -> DatasetEntry:
    # Batched map: keep only the first (human) and second (gpt) turn of each
    # conversation; remaining turns are discarded.
    human, gpt = [], []
    for conversation in examples['conversations']:
      human.append(conversation[0]['value'])
      gpt.append(conversation[1]['value'])
    return {'human': human, 'gpt': gpt}

  def process(examples, tokenizer, max_output_length: t.Optional[int]):
    # Tokenize the 'human' and 'gpt' values in batches.
    prompt_token_ids = tokenizer(examples['human']).input_ids
    completion_token_ids = tokenizer(examples['gpt']).input_ids

    # Create the transformed entries.
    return {
      'prompt': examples['human'],
      'prompt_length': [len(ids) for ids in prompt_token_ids],
      # BUGFIX: the caller-supplied `max_output_length` value was previously
      # ignored — any non-None argument produced FIXED_OUTPUT_LENGTH (128).
      # Honour the argument instead (matches vLLM's benchmark sampler).
      'output_length': [
        len(ids) if max_output_length is None else max_output_length for ids in completion_token_ids
      ],
    }

  def filter_length(examples) -> list[bool]:
    # Drop pairs that are too short to be meaningful (< 4 tokens either side)
    # or too long to fit a 2048-token context with the prompt capped at 1024.
    result = []
    for prompt_length, output_length in zip(examples['prompt_length'], examples['output_length']):
      if prompt_length < 4 or output_length < 4:
        result.append(False)
      elif prompt_length > 1024 or prompt_length + output_length > 2048:
        result.append(False)
      else:
        result.append(True)
    return result

  return (
    (
      # Walrus-bind the raw dataset so `dataset.column_names` is available to
      # the `.map(transform, ...)` call later in this same expression.
      dataset := load_dataset(
        'anon8231489123/ShareGPT_Vicuna_unfiltered',
        data_files='ShareGPT_V3_unfiltered_cleaned_split.json',
        split='train',
      )
    )
    .filter(lambda example: len(example['conversations']) >= 2, num_proc=8)
    .map(transform, remove_columns=dataset.column_names, batched=True)
    .map(
      process,
      fn_kwargs={'tokenizer': tokenizer, 'max_output_length': max_output_length},
      remove_columns=['human', 'gpt'],
      batched=True,
    )
    .filter(filter_length, batched=True)
    .shuffle(seed=42)
    .to_list()[:num_requests]
  )
|
||||
Reference in New Issue
Block a user