fix(serialisation): vLLM safetensors support (#324)

* fix(serilisation): vllm support for safetensors

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>

* chore: running tools

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: generalize one shot generation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: add changelog [skip ci]

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-09-12 17:44:01 -04:00
committed by GitHub
parent c70d4edcb1
commit 35e6945e86
11 changed files with 29 additions and 46 deletions

View File

@@ -24,17 +24,21 @@ class HfIgnore:
@classmethod
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
if llm.__llm_backend__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors, cls.gguf]
if llm.__llm_backend__ == 'vllm':
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
elif llm.__llm_backend__ == 'flax':
base = [cls.tf, cls.pt, cls.safetensors, cls.gguf] # as of current, safetensors is not supported with flax
elif llm.__llm_backend__ == 'pt':
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml':
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
else:
raise ValueError('Unknown backend (should never happen at all.)')
# filter out these files, since we probably don't need them for now.
base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice'])
return base