diff --git a/changelog.d/324.fix.md b/changelog.d/324.fix.md
new file mode 100644
index 00000000..f2c94ef8
--- /dev/null
+++ b/changelog.d/324.fix.md
@@ -0,0 +1,3 @@
+vLLM now supports the safetensors loading format, so `--serialisation` should now be agnostic of the backend
+
+Removed some legacy checks and default behaviour
diff --git a/openllm-python/pyproject.toml b/openllm-python/pyproject.toml
index 9d36a96b..40b3b50e 100644
--- a/openllm-python/pyproject.toml
+++ b/openllm-python/pyproject.toml
@@ -115,7 +115,7 @@ openai = ["openai", "tiktoken"]
 opt = ["flax>=0.7", "jax", "jaxlib", "tensorflow", "keras"]
 playground = ["jupyter", "notebook", "ipython", "jupytext", "nbformat"]
 starcoder = ["bitsandbytes"]
-vllm = ["vllm>=0.1.4", "ray"]
+vllm = ["vllm>=0.1.6", "ray"]
 
 [tool.hatch.version]
 fallback-version = "0.0.0"
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index bdfa1d6f..f3697ba6 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -281,8 +281,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
 
   def __attrs_init__(self,
                      config: LLMConfig,
-                     quantize: t.Optional[LiteralQuantise],
                      quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig]],
+                     quantize: t.Optional[LiteralQuantise],
                      model_id: str,
                      model_decls: TupleAny,
                      model_attrs: DictStrAny,
@@ -446,8 +446,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
     # in case users input `tokenizer` to __init__, default to the _model_id
     if quantize == 'gptq': attrs.setdefault('tokenizer', _model_id)
     quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs)
-    if quantize == 'gptq': serialisation = 'safetensors'
-    elif cls.__llm_backend__ == 'vllm': serialisation = 'legacy'  # Currently working-in-progress
 
     # NOTE: LoRA adapter setup
     if adapter_map and adapter_id:
@@ -534,12 +532,12 @@ class LLM(LLMInterface[M, T], ReprMixin):
                model_id: str,
                llm_config: LLMConfig,
                quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | None,
-               _adapters_mapping: AdaptersMapping | None,
-               _tag: bentoml.Tag,
                _quantize: LiteralQuantise | None,
                _model_version: str,
+               _tag: bentoml.Tag,
               _serialisation: t.Literal['safetensors', 'legacy'],
                _local: bool,
+               _adapters_mapping: AdaptersMapping | None,
                **attrs: t.Any,
               ):
     '''Initialize the LLM with given pretrained model.
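A note on the two lines deleted in the @@ -446 hunk above: they used to silently override whatever serialisation the caller asked for, depending on the quantisation method and backend. A minimal sketch of the behavioural change (the `resolve_serialisation_*` helpers are hypothetical, written only to contrast the two code paths; they are not part of openllm):

    from __future__ import annotations

    # Before this patch: the requested serialisation could be overridden per backend/quantisation.
    def resolve_serialisation_before(quantize: str | None, backend: str, serialisation: str) -> str:
      if quantize == 'gptq': return 'safetensors'
      elif backend == 'vllm': return 'legacy'  # older vLLM releases could not load safetensors (see changelog)
      return serialisation

    # After this patch: the requested `--serialisation` value is respected for every backend.
    def resolve_serialisation_after(quantize: str | None, backend: str, serialisation: str) -> str:
      return serialisation

Together with the `vllm>=0.1.6` bump in `pyproject.toml`, this is what allows `--serialisation safetensors` to be used unchanged with the vLLM backend.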
@@ -941,6 +939,16 @@ class LLM(LLMInterface[M, T], ReprMixin):
     prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs)
     return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs)
 
+  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
+    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
+    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
+    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
+    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
+    # Inference API returns the stop sequence
+    for stop_seq in stop:
+      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
+    return [{'generated_text': result}]
+
   def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]:
     # TODO: support different generation strategies, similar to self.model.generate
     for it in self.generate_iterator(prompt, **attrs):
diff --git a/openllm-python/src/openllm/models/falcon/modeling_falcon.py b/openllm-python/src/openllm/models/falcon/modeling_falcon.py
index 91ffaafb..03f524e9 100644
--- a/openllm-python/src/openllm/models/falcon/modeling_falcon.py
+++ b/openllm-python/src/openllm/models/falcon/modeling_falcon.py
@@ -20,13 +20,3 @@ class Falcon(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTraine
                                         attention_mask=inputs['attention_mask'],
                                         generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()),
                                         skip_special_tokens=True)
-
-  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
-    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
-    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
-    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
-    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
-    # Inference API returns the stop sequence
-    for stop_seq in stop:
-      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
-    return [{'generated_text': result}]
diff --git a/openllm-python/src/openllm/models/llama/modeling_llama.py b/openllm-python/src/openllm/models/llama/modeling_llama.py
index 48dbb434..6c37d1fa 100644
--- a/openllm-python/src/openllm/models/llama/modeling_llama.py
+++ b/openllm-python/src/openllm/models/llama/modeling_llama.py
@@ -24,13 +24,3 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
     masked_embeddings = data * mask
     sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
     return openllm.EmbeddingsOutput(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
-
-  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
-    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
-    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
-    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
-    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
-    # Inference API returns the stop sequence
-    for stop_seq in stop:
-      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
-    return [{'generated_text': result}]
diff --git a/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py b/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py
index fcf396cc..c7dd3b9a 100644
--- a/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py
+++ b/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py
@@ -44,13 +44,3 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
     # TODO: We will probably want to return the tokenizer here so that we can manually process this
     # return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
     return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-
-  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
-    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
-    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
-    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
-    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
-    # Inference API returns the stop sequence
-    for stop_seq in stop:
-      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
-    return [{'generated_text': result}]
diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py
index 9f0c1b07..e5aced50 100644
--- a/openllm-python/src/openllm/serialisation/transformers/__init__.py
+++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py
@@ -66,8 +66,6 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
   _, tokenizer_attrs = llm.llm_parameters
   quantize = llm._quantize
   safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
-  # Disable safe serialization with vLLM
-  if llm.__llm_backend__ == 'vllm': safe_serialisation = False
   metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
   if quantize: metadata['_quantize'] = quantize
   architectures = getattr(config, 'architectures', [])
diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py
index bbc2844f..d4318987 100644
--- a/openllm-python/src/openllm/serialisation/transformers/weights.py
+++ b/openllm-python/src/openllm/serialisation/transformers/weights.py
@@ -24,17 +24,21 @@ class HfIgnore:
   @classmethod
   def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
-    if llm.__llm_backend__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors, cls.gguf]
+    if llm.__llm_backend__ == 'vllm':
+      base = [cls.tf, cls.flax, cls.gguf]
+      if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
+      else: base.append(cls.safetensors)
     elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
     elif llm.__llm_backend__ == 'flax': base = [cls.tf, cls.pt, cls.safetensors, cls.gguf]  # as of current, safetensors is not supported with flax
     elif llm.__llm_backend__ == 'pt':
       base = [cls.tf, cls.flax, cls.gguf]
-      if has_safetensors_weights(llm.model_id): base.append(cls.pt)
+      if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
+      else: base.append(cls.safetensors)
     elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
     else: raise ValueError('Unknown backend (should never happen at all.)')
     # filter out these files, since we probably don't need them for now.
-    base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
+    base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice'])
     return base
diff --git a/openllm-python/src/openllm/utils/dummy_pt_objects.py b/openllm-python/src/openllm/utils/dummy_pt_objects.py
index 4242f0b9..7e61ccf5 100644
--- a/openllm-python/src/openllm/utils/dummy_pt_objects.py
+++ b/openllm-python/src/openllm/utils/dummy_pt_objects.py
@@ -19,8 +19,8 @@ class GPTNeoX(metaclass=_DummyMetaclass):
   _backends=["torch"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
 class Llama(metaclass=_DummyMetaclass):
-  _backends=["torch","fairscale","sentencepiece"]
-  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","fairscale","sentencepiece"])
+  _backends=["torch","fairscale","sentencepiece","scipy"]
+  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","fairscale","sentencepiece","scipy"])
 class MPT(metaclass=_DummyMetaclass):
   _backends=["torch","triton","einops"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","triton","einops"])
diff --git a/openllm-python/src/openllm/utils/dummy_vllm_objects.py b/openllm-python/src/openllm/utils/dummy_vllm_objects.py
index 048a4771..d837bf3e 100644
--- a/openllm-python/src/openllm/utils/dummy_vllm_objects.py
+++ b/openllm-python/src/openllm/utils/dummy_vllm_objects.py
@@ -28,8 +28,8 @@ class VLLMStarCoder(metaclass=_DummyMetaclass):
   _backends=["vllm","bitsandbytes"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","bitsandbytes"])
 class VLLMLlama(metaclass=_DummyMetaclass):
-  _backends=["vllm","fairscale","sentencepiece"]
-  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","fairscale","sentencepiece"])
+  _backends=["vllm","fairscale","sentencepiece","scipy"]
+  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","fairscale","sentencepiece","scipy"])
 class AutoVLLM(metaclass=_DummyMetaclass):
   _backends=["vllm"]
   def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
diff --git a/tools/dependencies.py b/tools/dependencies.py
index a46d868f..22e11e5e 100755
--- a/tools/dependencies.py
+++ b/tools/dependencies.py
@@ -45,7 +45,7 @@ class Classifier:
   def create_python_classifier(implementation: list[str] | None = None, supported_version: list[str] | None = None) -> list[str]:
     if supported_version is None: supported_version = ['3.8', '3.9', '3.10', '3.11', '3.12']
     if implementation is None: implementation = ['CPython', 'PyPy']
-    base = [Classifier.create_classifier('language', 'Python'), Classifier.create_classifier('language', 'Python', '3'),]
+    base = [Classifier.create_classifier('language', 'Python'), Classifier.create_classifier('language', 'Python', '3')]
     base.append(Classifier.create_classifier('language', 'Python', '3', 'Only'))
     base.extend([Classifier.create_classifier('language', 'Python', version) for version in supported_version])
     base.extend([Classifier.create_classifier('language', 'Python', 'Implementation', impl) for impl in implementation])
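A note on the `weights.py` hunk above: after this change the `pt` and `vllm` backends download exactly one PyTorch weight format, which is what makes `--serialisation` backend-agnostic at download time. Below is a standalone sketch of that selection logic; the glob constants are assumptions (the actual values of `HfIgnore.tf`, `HfIgnore.pt`, etc. are not shown in this diff) and the two now-identical `vllm`/`pt` branches are merged for brevity:

    from __future__ import annotations

    # Assumed Hugging Face weight-file globs; only the branching mirrors HfIgnore.ignore_patterns above.
    TF, FLAX, PT, SAFETENSORS, GGUF = '*.h5', '*.msgpack', '*.bin', '*.safetensors', '*.gguf'

    def ignore_patterns(backend: str, has_safetensors: bool, serialisation: str) -> list[str]:
      if backend in ('vllm', 'pt'):
        base = [TF, FLAX, GGUF]
        # ignore whichever torch format we will not use: skip *.bin when safetensors is present
        # or explicitly requested, otherwise skip *.safetensors and fall back to the pickled weights
        base.append(PT if (has_safetensors or serialisation == 'safetensors') else SAFETENSORS)
      elif backend == 'tf': base = [FLAX, PT, GGUF]
      elif backend == 'flax': base = [TF, PT, SAFETENSORS, GGUF]
      elif backend == 'ggml': base = [TF, FLAX, PT, SAFETENSORS]
      else: raise ValueError(f'unknown backend: {backend}')
      return base + ['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice']

    # e.g. a vLLM backend against a repo that ships safetensors skips TF/Flax/GGUF plus the pickled *.bin files
    assert ignore_patterns('vllm', True, 'safetensors') == ['*.h5', '*.msgpack', '*.gguf', '*.bin', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice']

The same pruning previously forced vLLM onto the legacy `*.bin` weights; with vLLM able to load safetensors, both torch-based backends now share one code path.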