mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-03 21:18:40 -05:00)
chore: ignore peft and fix adapter loading issue (#255)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -157,7 +157,7 @@ class FineTuneConfig:
   inference_mode: bool = dantic.Field(False, description='Whether to use this Adapter for inference', use_default_converter=False)
   llm_config_class: type[LLMConfig] = dantic.Field(None, description='The reference class to openllm.LLMConfig', use_default_converter=False)

-  def to_peft_config(self) -> peft.PeftConfig:
+  def to_peft_config(self) -> peft.PeftConfig: # type: ignore[name-defined]
     adapter_config = self.adapter_config.copy()
     # no need for peft_type since it is internally managed by OpenLLM and PEFT
     if 'peft_type' in adapter_config: adapter_config.pop('peft_type')
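For readers unfamiliar with PEFT, a minimal, hypothetical sketch of what dropping 'peft_type' buys when to_peft_config builds the PEFT config; the adapter values are illustrative, not OpenLLM defaults:

# Hedged sketch, not OpenLLM's code.
import peft

adapter_config = {'peft_type': 'LORA', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05}  # illustrative values
adapter_config.pop('peft_type', None)  # the concrete config class already encodes the adapter type
lora_config = peft.LoraConfig(task_type='CAUSAL_LM', **adapter_config)
print(lora_config.peft_type)  # PeftType.LORA, filled in by LoraConfig itself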
@@ -1,3 +1,4 @@
+# mypy: disable-error-code="name-defined,attr-defined"
 from __future__ import annotations
 import functools, inspect, logging, os, re, traceback, types, typing as t, uuid, attr, fs.path, inflection, orjson, bentoml, openllm, openllm_core, gc, pathlib, abc
 from huggingface_hub import hf_hub_download
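The new first line is a file-level mypy directive. A toy file showing the same pattern (the module and names below are made up for illustration):

# mypy: disable-error-code="name-defined,attr-defined"
# The directive above suppresses these two error codes for this file only,
# avoiding per-line ignores on names that exist solely at runtime.
from __future__ import annotations
import typing as t

if t.TYPE_CHECKING:
  import peft  # resolved only by the type checker

def train(cfg: peft.PeftConfig) -> None:  # hypothetical helper
  ...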
@@ -847,7 +848,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
     peft_config = self.config['fine_tune_strategies'].get(adapter_type, FineTuneConfig(adapter_type=t.cast('PeftType', adapter_type), llm_config_class=self.config_class)).train().with_config(
         **attrs
     ).to_peft_config()
-    wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config)
+    wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config) # type: ignore[no-untyped-call]
     if DEBUG: wrapped_peft.print_trainable_parameters()
     return wrapped_peft, self.tokenizer

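As a standalone illustration of the wrapping step in this hunk, a sketch against a small public checkpoint; the model name and LoRA hyperparameters are placeholders, not what OpenLLM picks:

import peft, transformers

model = transformers.AutoModelForCausalLM.from_pretrained('facebook/opt-125m')  # placeholder checkpoint
peft_config = peft.LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16)
# prepare_model_for_kbit_training freezes the base weights and enables gradient checkpointing
wrapped = peft.get_peft_model(peft.prepare_model_for_kbit_training(model, use_gradient_checkpointing=True), peft_config)
wrapped.print_trainable_parameters()  # only the adapter weights remain trainable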
@@ -28,7 +28,7 @@ generic_embedding_runner = bentoml.Runner(
 runners: list[AbstractRunner] = [runner]
 if not runner.supports_embeddings: runners.append(generic_embedding_runner)
 svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners)
-_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': ''})
+_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
 @svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
 async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
   qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
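For orientation, a hypothetical client call against the endpoint above now that the sample adapter_name is None; the URL, port, and prompt are placeholders:

import httpx

payload = {'prompt': 'What does OpenLLM do?', 'llm_config': {}, 'adapter_name': None}  # None: use the base model, no adapter
resp = httpx.post('http://localhost:3000/v1/generate', json=payload)
print(resp.json()['responses'])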
@@ -19,3 +19,13 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
     masked_embeddings = data * mask
     sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
     return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
+
+  def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
+    max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
+    src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
+    stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
+    result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
+    # Inference API returns the stop sequence
+    for stop_seq in stop:
+      if result.endswith(stop_seq): result = result[:-len(stop_seq)]
+    return [{'generated_text': result}]
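The stop-sequence handling in generate_one can be reproduced with transformers' stopping criteria directly (OpenLLM re-exports similar helpers); a minimal, hedged example using gpt2 as a stand-in model:

from __future__ import annotations
import torch, transformers

class StopOnSequence(transformers.StoppingCriteria):
  # Stop generation once the decoded text ends with any of the given stop strings.
  def __init__(self, stop: list[str], tokenizer: transformers.PreTrainedTokenizerBase) -> None:
    self.stop, self.tokenizer = stop, tokenizer
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: object) -> bool:
    text = self.tokenizer.decode(input_ids[0])
    return any(text.endswith(s) for s in self.stop)

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')  # stand-in checkpoint
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
encoded = tokenizer('Q: What is 2 + 2?\nA:', return_tensors='pt')
src_len = encoded['input_ids'].shape[1]
criteria = transformers.StoppingCriteriaList([StopOnSequence(['\n'], tokenizer)])
output = model.generate(encoded['input_ids'], max_new_tokens=32, stopping_criteria=criteria)
print(tokenizer.decode(output[0].tolist()[src_len:]))  # new tokens only; a trailing stop string may still need trimming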
@@ -406,8 +406,8 @@ exclude = [
   "openllm-python/src/openllm/_service.py",
   "openllm-core/src/openllm_core/_typing_compat.py",
 ]
-modules = ["openllm", "openllm-core", "openllm-client"]
-mypy_path = "typings"
+modules = ["openllm", "openllm_core", "openllm_client"]
+mypy_path = "typings:openllm-core/src:openllm-client/src"
 pretty = true
 python_version = "3.8"
 show_error_codes = true
@@ -415,6 +415,7 @@ strict = true
 warn_return_any = false
 warn_unreachable = true
 warn_unused_ignores = false
+explicit_package_bases = true
 [[tool.mypy.overrides]]
 ignore_missing_imports = true
 module = [