fix(serialisation): vLLM safetensors support (#324)

* fix(serilisation): vllm support for safetensors

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>

* chore: running tools

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: generalize one shot generation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: add changelog [skip ci]

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-09-12 17:44:01 -04:00
committed by GitHub
parent c70d4edcb1
commit 35e6945e86
11 changed files with 29 additions and 46 deletions

View File

@@ -44,13 +44,3 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
# TODO: We will probably want to return the tokenizer here so that we can manually process this
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
return [{'generated_text': result}]