mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-02 11:52:19 -05:00)
chore(style): reduce line length and truncate compression
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
@@ -1,16 +1,12 @@
from __future__ import annotations
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid, attr, fs.path, inflection, orjson, bentoml, openllm, openllm_core, gc
from abc import ABC, abstractmethod
from pathlib import Path
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid, attr, fs.path, inflection, orjson, bentoml, openllm, openllm_core, gc, pathlib, abc
from huggingface_hub import hf_hub_download
from bentoml._internal.models.model import ModelSignature

from openllm_core._configuration import FineTuneConfig, LLMConfig, _object_getattribute, _setattr_class
from ._quantisation import infer_quantisation_config
from openllm_core._schema import unmarshal_vllm_outputs
from .exceptions import ForbiddenAttributeError, GpuNotAvailableError, OpenLLMException
from .models.auto import AutoConfig
from openllm_core.utils import DEBUG, ENV_VARS_TRUE_VALUES, MYPY, EnvVarMixin, LazyLoader, ReprMixin, apply, bentoml_cattr, codegen, device_count, first_not_none, generate_hash_from_file, is_peft_available, is_torch_available, non_intrusive_setattr, normalize_attrs_to_model_tokenizer_pair, resolve_filepath, validate_is_path
from ._quantisation import infer_quantisation_config
from .exceptions import ForbiddenAttributeError, GpuNotAvailableError, OpenLLMException
from .utils import infer_auto_class
from openllm_core._typing_compat import AdaptersMapping, AdaptersTuple, AnyCallable, AdapterType, LiteralRuntime, DictStrAny, ListStr, LLMEmbeddings, LLMRunnable, LLMRunner, ModelSignatureDict as _ModelSignatureDict, PeftAdapterOutput, TupleAny, NotRequired, overload, M, T, LiteralString

@@ -68,7 +64,7 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapp
resolved[_peft_type] += (_AdaptersTuple((path_or_adapter_id, resolve_name, resolved_config)),)
return resolved
_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}
class LLMInterface(ABC, t.Generic[M, T]):
class LLMInterface(abc.ABC, t.Generic[M, T]):
"""This defines the loose contract for all openllm.LLM implementations."""
@property
def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None:
@@ -91,7 +87,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
"""
raise NotImplementedError

@abstractmethod
@abc.abstractmethod
def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any:
"""The implementation for text generation from given prompt.

@@ -141,7 +137,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
"""
raise NotImplementedError

def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None:
def save_pretrained(self, save_directory: str | pathlib.Path, **attrs: t.Any) -> None:
"""This function defines how this model can be saved to local store.

This will be called during ``import_model``. By default, it will use ``openllm.serialisation.save_pretrained``.
@@ -234,7 +230,7 @@ class _llm_post_init_wrapper(t.Generic[M, T], t.Protocol):
def __call__(self, llm: LLM[M, T]) -> T:
...
class _save_pretrained_wrapper(t.Generic[M, T], t.Protocol):
def __call__(self, llm: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None:
def __call__(self, llm: LLM[M, T], save_directory: str | pathlib.Path, **attrs: t.Any) -> None:
...
_object_setattr = object.__setattr__
# NOTE: the following wrapper are a light meta ops for wrapping default params to internal methods implementation.
@@ -250,7 +246,9 @@ def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Ca
return wrapper
_DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_engine_args(llm: LLM[M, T], tokenizer: str = _DEFAULT_TOKENIZER) -> vllm.EngineArgs:
return vllm.EngineArgs(model=llm._bentomodel.path, tokenizer=tokenizer, tokenizer_mode="auto", tensor_parallel_size=1 if device_count() < 2 else device_count(), dtype="auto", worker_use_ray=False)
return vllm.EngineArgs(
model=llm._bentomodel.path, tokenizer=tokenizer, tokenizer_mode="auto", tensor_parallel_size=1 if device_count() < 2 else device_count(), dtype="auto", worker_use_ray=False
)
def _wrapped_load_model(f: _load_model_wrapper[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.LLMEngine]:
@functools.wraps(f)
def wrapper(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
@@ -279,12 +277,13 @@ def _wrapped_llm_post_init(f: _llm_post_init_wrapper[M, T]) -> t.Callable[[LLM[M
f(self)

return wrapper
def _wrapped_save_pretrained(f: _save_pretrained_wrapper[M, T]) -> t.Callable[[LLM[M, T], str | Path], None]:
def _wrapped_save_pretrained(f: _save_pretrained_wrapper[M, T]) -> t.Callable[[LLM[M, T], str | pathlib.Path], None]:
@functools.wraps(f)
def wrapper(self: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None:
if isinstance(save_directory, Path): save_directory = str(save_directory)
def wrapper(self: LLM[M, T], save_directory: str | pathlib.Path, **attrs: t.Any) -> None:
if isinstance(save_directory, pathlib.Path): save_directory = str(save_directory)
if self.__llm_model__ is None: raise RuntimeError("Cannot 'save_pretrained' with unload model instance.")
if self.bettertransformer and self.__llm_implementation__ == "pt": _object_setattr(self, "__llm_model__", t.cast("transformers.PreTrainedModel", self.__llm_model__).reverse_bettertransformer())
if self.bettertransformer and self.__llm_implementation__ == "pt":
_object_setattr(self, "__llm_model__", t.cast("transformers.PreTrainedModel", self.__llm_model__).reverse_bettertransformer())
f(self, save_directory, **attrs)

return wrapper
@@ -300,7 +299,13 @@ def _update_docstring(cls: LLM[M, T], fn: str) -> AnyCallable:
setattr(cls, fn, original_fn)
return original_fn
def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], None]:
attributes = {"import_model": _wrapped_import_model, "load_model": _wrapped_load_model, "load_tokenizer": _wrapped_load_tokenizer, "llm_post_init": _wrapped_llm_post_init, "save_pretrained": _wrapped_save_pretrained}
attributes = {
"import_model": _wrapped_import_model,
"load_model": _wrapped_load_model,
"load_tokenizer": _wrapped_load_tokenizer,
"llm_post_init": _wrapped_llm_post_init,
"save_pretrained": _wrapped_save_pretrained
}
args: ListStr = []
anns: DictStrAny = {}
lines: ListStr = []
@@ -372,7 +377,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
cd = cls.__dict__
implementation, config_class_name = cls._infer_implementation_from_name(cls.__name__)
cls.__llm_implementation__ = implementation
config_class = AutoConfig.infer_class_from_name(config_class_name)
config_class = openllm.AutoConfig.infer_class_from_name(config_class_name)
if "__openllm_internal__" in cd:
if "config_class" not in cd: cls.config_class = config_class
elif "config_class" not in cd: raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.")
@@ -532,11 +537,14 @@ class LLM(LLMInterface[M, T], ReprMixin):
return f"{cls.__llm_implementation__}-{model_name}:{maybe_revision[0]}"

tag_name = f"{cls.__llm_implementation__}-{model_name}"
if os.environ.get("OPENLLM_USE_LOCAL_LATEST", str(False)).upper() in ENV_VARS_TRUE_VALUES: return bentoml_cattr.unstructure(bentoml.models.get(f"{tag_name}{':'+model_version if model_version is not None else ''}").tag)
if os.environ.get("OPENLLM_USE_LOCAL_LATEST", str(False)).upper() in ENV_VARS_TRUE_VALUES:
return bentoml_cattr.unstructure(bentoml.models.get(f"{tag_name}{':'+model_version if model_version is not None else ''}").tag)
if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id))
else:
from .serialisation.transformers._helpers import process_config
model_version = getattr(process_config(model_id, trust_remote_code=cls.config_class.__openllm_trust_remote_code__, revision=first_not_none(model_version, default="main"))[0], "_commit_hash", None)
model_version = getattr(
process_config(model_id, trust_remote_code=cls.config_class.__openllm_trust_remote_code__, revision=first_not_none(model_version, default="main"))[0], "_commit_hash", None
)
if model_version is None: raise ValueError(f"Internal errors when parsing config for pretrained '{model_id}' ('commit_hash' not found)")
return f"{tag_name}:{model_version}"

@@ -544,7 +552,22 @@ class LLM(LLMInterface[M, T], ReprMixin):
def generate_tag(cls, *param_decls: t.Any, **attrs: t.Any) -> bentoml.Tag:
return bentoml.Tag.from_taglike(cls._generate_tag_str(*param_decls, **attrs))

def __init__(self, *args: t.Any, model_id: str, llm_config: LLMConfig, bettertransformer: bool | None, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, _adapters_mapping: AdaptersMapping | None, _tag: bentoml.Tag, _quantize_method: t.Literal["int8", "int4", "gptq"] | None, _runtime: t.Literal["ggml", "transformers"], _model_version: str, _serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any,):
def __init__(
self,
*args: t.Any,
model_id: str,
llm_config: LLMConfig,
bettertransformer: bool | None,
quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None,
_adapters_mapping: AdaptersMapping | None,
_tag: bentoml.Tag,
_quantize_method: t.Literal["int8", "int4", "gptq"] | None,
_runtime: t.Literal["ggml", "transformers"],
_model_version: str,
_serialisation_format: t.Literal["safetensors", "legacy"],
_local: bool,
**attrs: t.Any,
):
"""Initialize the LLM with given pretrained model.

> [!WARNING]
@@ -641,10 +664,28 @@ class LLM(LLMInterface[M, T], ReprMixin):
# parsing tokenizer and model kwargs, as the hierachy is param pass > default
normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs)
# NOTE: Save the args and kwargs for latter load
self.__attrs_init__(llm_config, quantization_config, model_id, _runtime, args, {**model_kwds, **normalized_model_kwds}, {**tokenizer_kwds, **normalized_tokenizer_kwds}, _tag, _adapters_mapping, _model_version, _quantize_method, _serialisation_format, _local)
self.__attrs_init__(
llm_config,
quantization_config,
model_id,
_runtime,
args, {
**model_kwds, **normalized_model_kwds
}, {
**tokenizer_kwds, **normalized_tokenizer_kwds
},
_tag,
_adapters_mapping,
_model_version,
_quantize_method,
_serialisation_format,
_local
)
# handle trust_remote_code
_from_env = os.getenv("TRUST_REMOTE_CODE", None)
self.__llm_trust_remote_code__ = first_not_none(str(_from_env).upper() in ENV_VARS_TRUE_VALUES if _from_env else None, default=self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"]))
self.__llm_trust_remote_code__ = first_not_none(
str(_from_env).upper() in ENV_VARS_TRUE_VALUES if _from_env else None, default=self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"])
)

self.llm_post_init()
# we set it here so that we allow subclass to overwrite bettertransformer in llm_post_init
@@ -654,7 +695,10 @@ class LLM(LLMInterface[M, T], ReprMixin):
if _adapters_mapping and self.bettertransformer is True: self.bettertransformer = False

def __setattr__(self, attr: str, value: t.Any) -> None:
if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.")
if attr in _reserved_namespace:
raise ForbiddenAttributeError(
f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}."
)
super().__setattr__(attr, value)

@property
@@ -704,7 +748,15 @@ class LLM(LLMInterface[M, T], ReprMixin):
return self._tag

def ensure_model_id_exists(self) -> bentoml.Model:
return openllm.import_model(self.config["start_name"], model_id=self.model_id, model_version=self._model_version, runtime=self.runtime, implementation=self.__llm_implementation__, quantize=self._quantize_method, serialisation_format=self._serialisation_format)
return openllm.import_model(
self.config["start_name"],
model_id=self.model_id,
model_version=self._model_version,
runtime=self.runtime,
implementation=self.__llm_implementation__,
quantize=self._quantize_method,
serialisation_format=self._serialisation_format
)

@property
def _bentomodel(self) -> bentoml.Model:
@@ -747,7 +799,9 @@ class LLM(LLMInterface[M, T], ReprMixin):
try:
model = model.to("cuda")
except Exception as err:
raise OpenLLMException(f"Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.") from err
raise OpenLLMException(
f"Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information."
) from err
self.__llm_model__ = model
return self.__llm_model__

@@ -758,7 +812,9 @@ class LLM(LLMInterface[M, T], ReprMixin):
return self.__llm_tokenizer__

def _default_ft_config(self, _adapter_type: AdapterType, inference_mode: bool) -> FineTuneConfig:
strategy = first_not_none(self.config["fine_tune_strategies"].get(_adapter_type), default=FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), llm_config_class=self.config_class))
strategy = first_not_none(
self.config["fine_tune_strategies"].get(_adapter_type), default=FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), llm_config_class=self.config_class)
)
return strategy.eval() if inference_mode else strategy.train()

def _transpose_adapter_mapping(self, inference_mode: bool = True, use_cache: bool = True) -> ResolvedAdaptersMapping:
@@ -773,19 +829,24 @@ class LLM(LLMInterface[M, T], ReprMixin):
for _adapter_type, _adapters_tuples in self._adapters_mapping.items():
default_config = self._default_ft_config(_adapter_type, inference_mode)
for adapter in _adapters_tuples:
if not adapter.name and _converted_first_none: raise ValueError(f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {adapter.adapter_id, adapter.config}")
if not adapter.name and _converted_first_none:
raise ValueError(f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {adapter.adapter_id, adapter.config}")
name = adapter.name
if name is None:
_converted_first_none = True
name = "default"
peft_config = default_config.with_config(**adapter.config).to_peft_config() if name == "default" else FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), adapter_config=adapter.config, inference_mode=inference_mode, llm_config_class=self.config_class).to_peft_config()
peft_config = default_config.with_config(**adapter.config).to_peft_config() if name == "default" else FineTuneConfig(
adapter_type=t.cast("PeftType", _adapter_type), adapter_config=adapter.config, inference_mode=inference_mode, llm_config_class=self.config_class
).to_peft_config()
adapter_map[_adapter_type][name] = (peft_config, adapter.adapter_id)
if self.__llm_adapter_map__ is None and use_cache: self.__llm_adapter_map__ = adapter_map
return adapter_map

def prepare_for_training(self, adapter_type: AdapterType = "lora", use_gradient_checkpointing: bool = True, **attrs: t.Any) -> tuple[peft.PeftModel, T]:
from peft import prepare_model_for_kbit_training
peft_config = self.config["fine_tune_strategies"].get(adapter_type, FineTuneConfig(adapter_type=t.cast("PeftType", adapter_type), llm_config_class=self.config_class)).train().with_config(**attrs).to_peft_config()
peft_config = self.config["fine_tune_strategies"].get(adapter_type, FineTuneConfig(adapter_type=t.cast("PeftType", adapter_type), llm_config_class=self.config_class)).train().with_config(
**attrs
).to_peft_config()
wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config)
if DEBUG: wrapped_peft.print_trainable_parameters()
return wrapped_peft, self.tokenizer
@@ -846,7 +907,13 @@ class LLM(LLMInterface[M, T], ReprMixin):

# order of these fields matter here, make sure to sync it with
# openllm.models.auto.factory.BaseAutoLLMClass.for_model
def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, scheduling_strategy: type[bentoml.Strategy] = openllm_core.CascadingResourceStrategy) -> LLMRunner[M, T]:
def to_runner(
self,
models: list[bentoml.Model] | None = None,
max_batch_size: int | None = None,
max_latency_ms: int | None = None,
scheduling_strategy: type[bentoml.Strategy] = openllm_core.CascadingResourceStrategy
) -> LLMRunner[M, T]:
"""Convert this LLM into a Runner.

Args:
@@ -879,7 +946,18 @@ class LLM(LLMInterface[M, T], ReprMixin):
generate_iterator_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False)))

# NOTE: returning the two langchain API's to the runner
return llm_runner_class(self)(llm_runnable_class(self, embeddings_sig, generate_sig, generate_iterator_sig), name=self.runner_name, embedded=False, models=models, max_batch_size=max_batch_size, max_latency_ms=max_latency_ms, method_configs=bentoml_cattr.unstructure({"embeddings": embeddings_sig, "__call__": generate_sig, "generate": generate_sig, "generate_one": generate_sig, "generate_iterator": generate_iterator_sig}), scheduling_strategy=scheduling_strategy,)
return llm_runner_class(self)(
llm_runnable_class(self, embeddings_sig, generate_sig, generate_iterator_sig),
name=self.runner_name,
embedded=False,
models=models,
max_batch_size=max_batch_size,
max_latency_ms=max_latency_ms,
method_configs=bentoml_cattr.unstructure({
"embeddings": embeddings_sig, "__call__": generate_sig, "generate": generate_sig, "generate_one": generate_sig, "generate_iterator": generate_iterator_sig
}),
scheduling_strategy=scheduling_strategy,
)

# NOTE: Scikit API
def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
@@ -908,7 +986,18 @@ class LLM(LLMInterface[M, T], ReprMixin):
pass
return [it]

def generate_iterator(self, prompt: str, /, *, context_length: int | None = None, echo: bool = True, stream_interval: int = 2, stop: str | t.Iterable[str] | None = None, stop_token_ids: list[int] | None = None, **attrs: t.Any) -> t.Iterator[t.Any]:
def generate_iterator(
self,
prompt: str,
/,
*,
context_length: int | None = None,
echo: bool = True,
stream_interval: int = 2,
stop: str | t.Iterable[str] | None = None,
stop_token_ids: list[int] | None = None,
**attrs: t.Any
) -> t.Iterator[t.Any]:
# NOTE: encoder-decoder models will need to implement their own generate_iterator for now
# inspired from fastchat's generate_stream_func
from ._generation import prepare_logits_processor, get_context_length, is_partial_stop
@@ -937,7 +1026,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
logits = out.logits
past_key_values = out.past_key_values

last_token_logits = logits_processor(torch.as_tensor([output_ids], device=logits.device) if self.config["repetition_penalty"] > 1.0 else None, logits[:, -1, :])[0] if logits_processor else logits[0, -1, :]
last_token_logits = logits_processor(torch.as_tensor([output_ids], device=logits.device)
if self.config["repetition_penalty"] > 1.0 else None, logits[:, -1, :])[0] if logits_processor else logits[0, -1, :]
# Switch to CPU by avoiding some bugs in mps backend.
if self.device.type == "mps": last_token_logits = last_token_logits.float().to("cpu")