mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-26 00:07:51 -05:00
fix(serialisation): vLLM safetensors support (#324)
* fix(serilisation): vllm support for safetensors Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> * chore: running tools Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: generalize one shot generation Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: add changelog [skip ci] Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -66,8 +66,6 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize = llm._quantize
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
# Disable safe serialization with vLLM
|
||||
if llm.__llm_backend__ == 'vllm': safe_serialisation = False
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
|
||||
if quantize: metadata['_quantize'] = quantize
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
|
||||
@@ -24,17 +24,21 @@ class HfIgnore:
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_backend__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors, cls.gguf]
|
||||
if llm.__llm_backend__ == 'vllm':
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
|
||||
else: base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
|
||||
elif llm.__llm_backend__ == 'flax':
|
||||
base = [cls.tf, cls.pt, cls.safetensors, cls.gguf] # as of current, safetensors is not supported with flax
|
||||
elif llm.__llm_backend__ == 'pt':
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
|
||||
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
|
||||
else: base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml':
|
||||
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
else:
|
||||
raise ValueError('Unknown backend (should never happen at all.)')
|
||||
# filter out these files, since we probably don't need them for now.
|
||||
base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
|
||||
base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice'])
|
||||
return base
|
||||
|
||||
Reference in New Issue
Block a user