chore: ignore peft and fix adapter loading issue (#255)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-08-25 04:36:35 -04:00
committed by GitHub
parent 3b40a83817
commit 08dc6ed2ba
5 changed files with 17 additions and 5 deletions

View File

@@ -157,7 +157,7 @@ class FineTuneConfig:
inference_mode: bool = dantic.Field(False, description='Whether to use this Adapter for inference', use_default_converter=False)
llm_config_class: type[LLMConfig] = dantic.Field(None, description='The reference class to openllm.LLMConfig', use_default_converter=False)
def to_peft_config(self) -> peft.PeftConfig:
def to_peft_config(self) -> peft.PeftConfig: # type: ignore[name-defined]
adapter_config = self.adapter_config.copy()
# no need for peft_type since it is internally managed by OpenLLM and PEFT
if 'peft_type' in adapter_config: adapter_config.pop('peft_type')

View File

@@ -1,3 +1,4 @@
# mypy: disable-error-code="name-defined,attr-defined"
from __future__ import annotations
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid, attr, fs.path, inflection, orjson, bentoml, openllm, openllm_core, gc, pathlib, abc
from huggingface_hub import hf_hub_download
@@ -847,7 +848,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
peft_config = self.config['fine_tune_strategies'].get(adapter_type, FineTuneConfig(adapter_type=t.cast('PeftType', adapter_type), llm_config_class=self.config_class)).train().with_config(
**attrs
).to_peft_config()
wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config)
wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config) # type: ignore[no-untyped-call]
if DEBUG: wrapped_peft.print_trainable_parameters()
return wrapped_peft, self.tokenizer

View File

@@ -28,7 +28,7 @@ generic_embedding_runner = bentoml.Runner(
runners: list[AbstractRunner] = [runner]
if not runner.supports_embeddings: runners.append(generic_embedding_runner)
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners)
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': ''})
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)

View File

@@ -19,3 +19,13 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
masked_embeddings = data * mask
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
return [{'generated_text': result}]

View File

@@ -406,8 +406,8 @@ exclude = [
"openllm-python/src/openllm/_service.py",
"openllm-core/src/openllm_core/_typing_compat.py",
]
modules = ["openllm", "openllm-core", "openllm-client"]
mypy_path = "typings"
modules = ["openllm", "openllm_core", "openllm_client"]
mypy_path = "typings:openllm-core/src:openllm-client/src"
pretty = true
python_version = "3.8"
show_error_codes = true
@@ -415,6 +415,7 @@ strict = true
warn_return_any = false
warn_unreachable = true
warn_unused_ignores = false
explicit_package_bases = true
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [