Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-21 14:02:20 -05:00)
chore: ignore peft and fix adapter loading issue (#255)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -19,3 +19,13 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
     masked_embeddings = data * mask
     sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
     return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
+
+  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
+    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
+    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
+    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
+    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
+    # Inference API returns the stop sequence
+    for stop_seq in stop:
+      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
+    return [{'generated_text': result}]
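For context, the three unchanged lines at the top of the hunk come from the embeddings path: hidden states are zeroed at padding positions via the attention mask, mean-pooled over the sequence dimension, and L2-normalized. A minimal standalone sketch of that pooling follows; the function and variable names here are illustrative, not OpenLLM APIs.

import torch
import torch.nn.functional as F

def mean_pool_embeddings(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_embeddings = hidden_states * mask                     # zero out padded positions
    sum_embeddings = torch.sum(masked_embeddings, dim=1)         # (batch, hidden)
    seq_length = torch.sum(mask, dim=1)                          # real-token count per row
    return F.normalize(sum_embeddings / seq_length, p=2, dim=1)  # unit-length sentence vectors

# toy check: two sequences, the second padded to length 4
hidden = torch.randn(2, 4, 8)
attn = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
print(mean_pool_embeddings(hidden, attn).shape)  # torch.Size([2, 8])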
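The added generate_one method decodes only the newly generated tokens (slicing off the first src_len prompt positions) and halts generation on the caller-supplied stop strings via openllm.StopSequenceCriteria; because the inference API still returns the stop sequence itself, the trailing loop trims it off. Below is a hedged sketch of how such a criterion can be built on transformers.StoppingCriteria; the class name and usage are illustrative assumptions, not the actual openllm.StopSequenceCriteria implementation.

import torch
import transformers

class StopOnSequences(transformers.StoppingCriteria):
    # Illustrative only: stop generation once the decoded text ends with any stop string.
    def __init__(self, stop: list[str], tokenizer) -> None:
        self.stop, self.tokenizer = stop, tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0].tolist())
        return any(text.endswith(s) for s in self.stop)

# usage sketch, assuming `tokenizer` and `model` are already loaded:
#   criteria = transformers.StoppingCriteriaList([StopOnSequences(['\nUser:'], tokenizer)])
#   model.generate(input_ids, stopping_criteria=criteria, max_new_tokens=200)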