diff --git a/openllm-python/src/openllm/_assign.py b/openllm-python/src/openllm/_assign.py
index 8113800e..c10b5c73 100644
--- a/openllm-python/src/openllm/_assign.py
+++ b/openllm-python/src/openllm/_assign.py
@@ -1,6 +1,7 @@
 '''LLM assignment magik.'''
 from __future__ import annotations
 import functools
+import math
 import traceback
 import typing as t
 
@@ -50,13 +51,15 @@ def load_model(fn: load_model_protocol[M, T]) -> t.Callable[[LLM[M, T]], M | vll
   @functools.wraps(fn)
   def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
     if self.__llm_backend__ == 'vllm':
+      num_gpus, dev = 1, device_count()
+      if dev >= 2: num_gpus = dev if dev // 2 == 0 else math.ceil(dev / 2)
       # TODO: Do some more processing with token_id once we support token streaming
       try:
         return vllm.LLMEngine.from_engine_args(
             vllm.EngineArgs(model=self._bentomodel.path,
                             tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
                             tokenizer_mode='auto',
-                            tensor_parallel_size=1 if device_count() < 2 else device_count(),
+                            tensor_parallel_size=num_gpus,
                             dtype='auto',
                             worker_use_ray=False))
       except Exception as err:
diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py
index e1218444..bbc2844f 100644
--- a/openllm-python/src/openllm/serialisation/transformers/weights.py
+++ b/openllm-python/src/openllm/serialisation/transformers/weights.py
@@ -20,16 +20,21 @@ class HfIgnore:
   pt = '*.bin'
   tf = '*.h5'
   flax = '*.msgpack'
+  gguf = '*.gguf'
 
   @classmethod
   def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
-    if llm.__llm_backend__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors]
-    elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt]
+    if llm.__llm_backend__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors, cls.gguf]
+    elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
     elif llm.__llm_backend__ == 'flax':
-      base = [cls.tf, cls.pt,
-              cls.safetensors]  # as of current, safetensors is not supported with flax
+      base = [cls.tf, cls.pt, cls.safetensors, cls.gguf]  # as of current, safetensors is not supported with flax
+    elif llm.__llm_backend__ == 'pt':
+      base = [cls.tf, cls.flax, cls.gguf]
+      if has_safetensors_weights(llm.model_id): base.append(cls.pt)
+    elif llm.__llm_backend__ == 'ggml':
+      base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
+    else:
+      raise ValueError('Unknown backend (should never happen at all.)')
     # filter out these files, since we probably don't need them for now.
     base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
     return base