Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-06 06:29:21 -05:00)
fix(serialisation): vLLM safetensors support (#324)
* fix(serialisation): vllm support for safetensors
* chore: running tools
* chore: generalize one shot generation
* chore: add changelog [skip ci]

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
@@ -115,7 +115,7 @@ openai = ["openai", "tiktoken"]
 opt = ["flax>=0.7", "jax", "jaxlib", "tensorflow", "keras"]
 playground = ["jupyter", "notebook", "ipython", "jupytext", "nbformat"]
 starcoder = ["bitsandbytes"]
-vllm = ["vllm>=0.1.4", "ray"]
+vllm = ["vllm>=0.1.6", "ray"]

 [tool.hatch.version]
 fallback-version = "0.0.0"
@@ -281,8 +281,8 @@ class LLM(LLMInterface[M, T], ReprMixin):

   def __attrs_init__(self,
                      config: LLMConfig,
-                     quantize: t.Optional[LiteralQuantise],
                      quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig]],
+                     quantize: t.Optional[LiteralQuantise],
                      model_id: str,
                      model_decls: TupleAny,
                      model_attrs: DictStrAny,
@@ -446,8 +446,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
     # in case users input `tokenizer` to __init__, default to the _model_id
     if quantize == 'gptq': attrs.setdefault('tokenizer', _model_id)
     quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs)
-    if quantize == 'gptq': serialisation = 'safetensors'
-    elif cls.__llm_backend__ == 'vllm': serialisation = 'legacy'  # Currently working-in-progress

     # NOTE: LoRA adapter setup
     if adapter_map and adapter_id:
@@ -534,12 +532,12 @@ class LLM(LLMInterface[M, T], ReprMixin):
                     model_id: str,
                     llm_config: LLMConfig,
                     quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | None,
-                    _adapters_mapping: AdaptersMapping | None,
-                    _tag: bentoml.Tag,
                     _quantize: LiteralQuantise | None,
                     _model_version: str,
+                    _tag: bentoml.Tag,
                     _serialisation: t.Literal['safetensors', 'legacy'],
                     _local: bool,
+                    _adapters_mapping: AdaptersMapping | None,
                     **attrs: t.Any,
                     ):
     '''Initialize the LLM with given pretrained model.
@@ -941,6 +939,16 @@ class LLM(LLMInterface[M, T], ReprMixin):
     prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs)
     return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs)

+  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
+    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
+    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
+    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
+    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
+    # Inference API returns the stop sequence
+    for stop_seq in stop:
+      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
+    return [{'generated_text': result}]
+
   def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]:
     # TODO: support different generation strategies, similar to self.model.generate
     for it in self.generate_iterator(prompt, **attrs):
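The only post-processing in the now-shared `generate_one` that is not a plain `generate` call is the trailing stop-sequence trimming. The snippet below is a hypothetical, dependency-free sketch of that step (the helper name is illustrative, not an OpenLLM API):

```python
# Hypothetical illustration of the trailing stop-sequence trimming that
# generate_one performs after decoding; not part of the commit itself.
def trim_stop_sequences(text: str, stop: list[str]) -> str:
  # The decoded output keeps the stop sequence at the end (Inference-API style),
  # so strip whichever stop string the text terminates with.
  for stop_seq in stop:
    if text.endswith(stop_seq):
      text = text[:-len(stop_seq)]
  return text

print(trim_stop_sequences('def add(a, b):\n  return a + b\n\n', ['\n\n']))
# -> 'def add(a, b):\n  return a + b'
```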
@@ -20,13 +20,3 @@ class Falcon(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTraine
                                   attention_mask=inputs['attention_mask'],
                                   generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()),
                                   skip_special_tokens=True)
-
-  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
-    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
-    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
-    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
-    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
-    # Inference API returns the stop sequence
-    for stop_seq in stop:
-      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
-    return [{'generated_text': result}]
@@ -24,13 +24,3 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
     masked_embeddings = data * mask
     sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
     return openllm.EmbeddingsOutput(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
-
-  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
-    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
-    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
-    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
-    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
-    # Inference API returns the stop sequence
-    for stop_seq in stop:
-      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
-    return [{'generated_text': result}]
@@ -44,13 +44,3 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
     # TODO: We will probably want to return the tokenizer here so that we can manually process this
     # return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
     return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-
-  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
-    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
-    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
-    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
-    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
-    # Inference API returns the stop sequence
-    for stop_seq in stop:
-      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
-    return [{'generated_text': result}]
@@ -66,8 +66,6 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
   _, tokenizer_attrs = llm.llm_parameters
   quantize = llm._quantize
   safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
-  # Disable safe serialization with vLLM
-  if llm.__llm_backend__ == 'vllm': safe_serialisation = False
   metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
   if quantize: metadata['_quantize'] = quantize
   architectures = getattr(config, 'architectures', [])
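With the vLLM override removed, `safe_serialisation` falls back to the serialisation format of the model itself. Assuming `openllm.utils.first_not_none` returns the first non-None argument and otherwise the `default` keyword, the resolution above can be sketched with a minimal stand-in (not the library implementation):

```python
# Minimal stand-in for openllm.utils.first_not_none (assumed semantics:
# first argument that is not None, else the `default` keyword).
import typing as t

def first_not_none(*args: t.Any, default: t.Any) -> t.Any:
  return next((a for a in args if a is not None), default)

attrs: dict = {}               # caller passed no explicit safe_serialization
serialisation = 'safetensors'  # stand-in for llm._serialisation
safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=serialisation == 'safetensors')
print(safe_serialisation)  # True: vLLM models are now also saved as safetensors
```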
@@ -24,17 +24,21 @@ class HfIgnore:

   @classmethod
   def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
-    if llm.__llm_backend__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors, cls.gguf]
+    if llm.__llm_backend__ == 'vllm':
+      base = [cls.tf, cls.flax, cls.gguf]
+      if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
+      else: base.append(cls.safetensors)
     elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
     elif llm.__llm_backend__ == 'flax':
       base = [cls.tf, cls.pt, cls.safetensors, cls.gguf]  # as of current, safetensors is not supported with flax
     elif llm.__llm_backend__ == 'pt':
       base = [cls.tf, cls.flax, cls.gguf]
-      if has_safetensors_weights(llm.model_id): base.append(cls.pt)
+      if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
       else: base.append(cls.safetensors)
     elif llm.__llm_backend__ == 'ggml':
       base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
     else:
       raise ValueError('Unknown backend (should never happen at all.)')
     # filter out these files, since we probably don't need them for now.
-    base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
+    base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice'])
     return base
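Pattern lists like this are typically handed to a Hugging Face Hub download call. The sketch below uses the public `ignore_patterns` parameter of `huggingface_hub.snapshot_download`; the concrete glob values stand in for `cls.tf`, `cls.flax`, `cls.pt` and `cls.gguf` and are assumptions for illustration only, as is the idea that OpenLLM feeds `HfIgnore.ignore_patterns` into this exact call:

```python
# Hedged sketch: skip unneeded weight formats when pulling a repo from the Hub.
# The glob values are assumed stand-ins for HfIgnore's class attributes; only
# ignore_patterns= is a real huggingface_hub parameter.
from huggingface_hub import snapshot_download

ignore = [
    '*.h5',        # assumed TensorFlow weights pattern
    '*.msgpack',   # assumed Flax weights pattern
    '*.gguf',      # assumed GGUF pattern
    '*.bin',       # assumed PyTorch pattern, skipped when safetensors are available
    '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice',
]
snapshot_download('facebook/opt-125m', ignore_patterns=ignore)
```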
@@ -19,8 +19,8 @@ class GPTNeoX(metaclass=_DummyMetaclass):
   _backends=["torch"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
 class Llama(metaclass=_DummyMetaclass):
-  _backends=["torch","fairscale","sentencepiece"]
-  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","fairscale","sentencepiece"])
+  _backends=["torch","fairscale","sentencepiece","scipy"]
+  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","fairscale","sentencepiece","scipy"])
 class MPT(metaclass=_DummyMetaclass):
   _backends=["torch","triton","einops"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","triton","einops"])
@@ -28,8 +28,8 @@ class VLLMStarCoder(metaclass=_DummyMetaclass):
   _backends=["vllm","bitsandbytes"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","bitsandbytes"])
 class VLLMLlama(metaclass=_DummyMetaclass):
-  _backends=["vllm","fairscale","sentencepiece"]
-  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","fairscale","sentencepiece"])
+  _backends=["vllm","fairscale","sentencepiece","scipy"]
+  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","fairscale","sentencepiece","scipy"])
 class AutoVLLM(metaclass=_DummyMetaclass):
   _backends=["vllm"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
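These dummy classes follow the usual optional-backend guard pattern: placeholders that only raise when instantiated without the required extras. A hedged, self-contained sketch of that pattern follows; the bodies of `_require_backends` and `_DummyMetaclass` are illustrative assumptions, not OpenLLM's actual implementation:

```python
# Illustrative sketch of the dummy-object pattern used above (assumed bodies).
import importlib.util
import typing as _t

def _require_backends(obj: _t.Any, backends: list[str]) -> None:
  # Raise a helpful error listing whichever optional backends are not installed.
  missing = [b for b in backends if importlib.util.find_spec(b) is None]
  if missing:
    raise ImportError(f"{type(obj).__name__} requires the missing backends: {', '.join(missing)}")

class _DummyMetaclass(type):
  pass  # attribute access on the class itself could be guarded here as well

class VLLMLlama(metaclass=_DummyMetaclass):
  _backends = ["vllm", "fairscale", "sentencepiece", "scipy"]
  def __init__(self, *param_decls: _t.Any, **attrs: _t.Any): _require_backends(self, ["vllm", "fairscale", "sentencepiece", "scipy"])
```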