mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 10:29:36 -04:00
fix(yapf): align weird new lines break [generated] [skip ci] (#284)
fix(yapf): align weird new lines break Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -28,19 +28,10 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
|
||||
import transformers
|
||||
torch_dtype, device_map = attrs.pop('torch_dtype', torch.float16), attrs.pop('device_map', 'auto')
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({
|
||||
'additional_special_tokens': [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
|
||||
'pad_token': EOD
|
||||
})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
**attrs)
|
||||
tokenizer.add_special_tokens({'additional_special_tokens': [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], 'pad_token': EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
model,
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -49,26 +40,21 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
|
||||
with torch.inference_mode():
|
||||
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
|
||||
# NOTE: support fine-tuning starcoder
|
||||
result_tensor = self.model.generate(
|
||||
self.tokenizer.encode(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
# TODO: We will probably want to return the tokenizer here so that we can manually process this
|
||||
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
|
||||
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str],
|
||||
**preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(
|
||||
prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop(
|
||||
'stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria',
|
||||
openllm.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(
|
||||
self.model.generate(encoded_inputs['input_ids'],
|
||||
max_new_tokens=max_new_tokens,
|
||||
stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
|
||||
|
||||
Reference in New Issue
Block a user