Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-02-18 22:55:08 -05:00
refactor: update runner helpers and add max_model_len (#712)
* chore(runner): cleanup unnecessary checks for runnable backend
* chore: saving llm reference to runner
* chore: correct inject item
* chore: update support for max_seq_len
* fix: correct max_model_len
* chore: update and warn about backward compatibility
* chore: remove unused sets

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
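For reference, a minimal sketch of how the max_model_len option introduced by this commit is meant to be used: it is accepted by the LLM constructor, stored as _max_model_len, and forwarded into vllm.AsyncEngineArgs by the runner. The model id and the value below are illustrative assumptions, not taken from this commit.

# Illustrative sketch only (not part of this commit); assumes the vLLM backend is installed
# and that the referenced model is available locally or on the Hugging Face Hub.
import openllm

llm = openllm.LLM(
  'facebook/opt-1.3b',  # hypothetical model id, used purely for illustration
  backend='vllm',
  max_model_len=4096,  # new in this PR: caps the context length passed to vllm.AsyncEngineArgs
)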
@@ -2,7 +2,6 @@ from __future__ import annotations
import functools
import logging
import os
import types
import typing as t

import attr
@@ -11,7 +10,6 @@ import orjson

import bentoml
import openllm
from bentoml._internal.models.model import ModelSignature
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import (
  AdapterMap,
@@ -34,7 +32,6 @@ from openllm_core.utils import (
  ReprMixin,
  apply,
  codegen,
  converter,
  first_not_none,
  flatten_attrs,
  gen_random_uuid,
@@ -131,6 +128,7 @@ class LLM(t.Generic[M, T], ReprMixin):
  _adapter_map: AdapterMap | None
  _serialisation: LiteralSerialisation
  _local: bool
  _max_model_len: int | None
  _prompt_template: PromptTemplate | None
  _system_message: str | None

@@ -163,6 +161,7 @@ class LLM(t.Generic[M, T], ReprMixin):
    embedded=False,
    dtype='auto',
    low_cpu_mem_usage=True,
    max_model_len=None,
    _eager=True,
    **attrs,
  ):
@@ -192,6 +191,7 @@ class LLM(t.Generic[M, T], ReprMixin):
      adapter_map=_resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
      serialisation=serialisation,
      local=_local,
      max_model_len=max_model_len,
      prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
      system_message=system_message,
      LLM__model_attrs=model_attrs,
@@ -327,7 +327,8 @@ class LLM(t.Generic[M, T], ReprMixin):
    return self.__llm_tokenizer__
  @property
  def runner(self):
    if self.__llm_runner__ is None: self.__llm_runner__ = _RunnerFactory(self)
    from ._runners import runner
    if self.__llm_runner__ is None: self.__llm_runner__ = runner(self)
    return self.__llm_runner__
  def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs):
    if self.__llm_backend__ != 'pt': raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
@@ -421,6 +422,8 @@ class LLM(t.Generic[M, T], ReprMixin):
  async def generate(
    self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
  ) -> GenerationOutput:
    if adapter_name is not None and self.__llm_backend__ != 'pt':
      raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
    config = self.config.model_construct_env(**attrs)
    texts, token_ids = [[]] * config['n'], [[]] * config['n']
    final_result = None
@@ -446,6 +449,9 @@
  ) -> t.AsyncGenerator[GenerationOutput, None]:
    from bentoml._internal.runner.runner_handle import DummyRunnerHandle

    if adapter_name is not None and self.__llm_backend__ != 'pt':
      raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')

    if isinstance(self.runner._runner_handle, DummyRunnerHandle):
      if os.getenv('BENTO_PATH') is not None:
        raise RuntimeError('Runner client failed to set up correctly.')
@@ -487,82 +493,3 @@ class LLM(t.Generic[M, T], ReprMixin):
        previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
        delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
      yield generated.with_options(outputs=delta_outputs)


def _RunnerFactory(
  llm, /, models=None, max_batch_size=None, max_latency_ms=None, scheduling_strategy=None, *, backend=None
):
  from ._runners import runnable

  if scheduling_strategy is None:
    from ._strategies import CascadingResourceStrategy

    scheduling_strategy = CascadingResourceStrategy

  backend = first_not_none(getenv('backend', default=backend), default=llm.__llm_backend__)

  models = models if models is not None else []
  try:
    models.append(llm.bentomodel)
  except bentoml.exceptions.NotFound as err:
    raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err

  if llm._prompt_template:
    prompt_template = llm._prompt_template.to_string()
  elif hasattr(llm.config, 'default_prompt_template'):
    prompt_template = llm.config.default_prompt_template
  else:
    prompt_template = None
  if llm._system_message:
    system_message = llm._system_message
  elif hasattr(llm.config, 'default_system_message'):
    system_message = llm.config.default_system_message
  else:
    system_message = None
  return types.new_class(
    llm.config.__class__.__name__[:-6] + 'Runner',
    (bentoml.Runner,),
    exec_body=lambda ns: ns.update(
      {
        'llm_type': llm.llm_type,
        'identifying_params': llm.identifying_params,
        'llm_tag': llm.tag,
        'llm': llm,
        'config': llm.config,
        'backend': backend,
        '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
        '__module__': llm.__module__,
        '__repr__': ReprMixin.__repr__,
        '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
        '__repr_args__': lambda _: (
          (
            'runner_methods',
            {
              method.name: {
                'batchable': method.config.batchable,
                'batch_dim': method.config.batch_dim if method.config.batchable else None,
              }
              for method in _.runner_methods
            },
          ),
          ('config', llm.config.model_dump(flatten=True)),
          ('llm_type', llm.llm_type),
          ('backend', backend),
          ('llm_tag', llm.tag),
        ),
        'has_adapters': llm.has_adapters,
        'prompt_template': prompt_template,
        'system_message': system_message,
      }
    ),
  )(
    runnable(llm, backend),
    name=llm.runner_name,
    embedded=False,
    models=models,
    max_batch_size=max_batch_size,
    max_latency_ms=max_latency_ms,
    scheduling_strategy=scheduling_strategy,
    runnable_init_params={'llm': llm},
    method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}),
  )

@@ -1,6 +1,7 @@
from __future__ import annotations
import gc
import traceback
import types
import typing as t

import torch
@@ -9,23 +10,90 @@ import bentoml
import openllm
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, getenv, is_ctranslate_available
from openllm_core.utils import ReprMixin, is_ctranslate_available, is_vllm_available

__all__ = ['runnable']
__all__ = ['runner']

_registry = {}


def runnable(llm, backend=None):
  backend = first_not_none(getenv('backend', default=backend), default=llm._cascade_backend())
  if backend == 'vllm':
    return vLLMRunnable
  elif backend == 'pt':
    return PyTorchRunnable
  elif backend == 'ctranslate':
    return CTranslateRunnable
def registry(cls=None, *, alias=None):
  def decorator(_cls):
    _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls
    return _cls

  if cls is None:
    return decorator
  return decorator(cls)


def runner(llm):
  from ._strategies import CascadingResourceStrategy

  try:
    models = [llm.bentomodel]
  except bentoml.exceptions.NotFound as err:
    raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err

  if llm._prompt_template:
    prompt_template = llm._prompt_template.to_string()
  elif hasattr(llm.config, 'default_prompt_template'):
    prompt_template = llm.config.default_prompt_template
  else:
    raise OpenLLMException(f'Unsupported backend: {backend}')
    prompt_template = None
  if llm._system_message:
    system_message = llm._system_message
  elif hasattr(llm.config, 'default_system_message'):
    system_message = llm.config.default_system_message
  else:
    system_message = None

  return types.new_class(
    llm.config.__class__.__name__[:-6] + 'Runner',
    (bentoml.Runner,),
    exec_body=lambda ns: ns.update(
      {
        'llm_type': llm.llm_type,
        'identifying_params': llm.identifying_params,
        'llm_tag': llm.tag,
        'llm': llm,
        'config': llm.config,
        'backend': llm.__llm_backend__,
        '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
        '__module__': llm.__module__,
        '__repr__': ReprMixin.__repr__,
        '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
        '__repr_args__': lambda _: (
          (
            'runner_methods',
            {
              method.name: {
                'batchable': method.config.batchable,
                'batch_dim': method.config.batch_dim if method.config.batchable else None,
              }
              for method in _.runner_methods
            },
          ),
          ('config', llm.config.model_dump(flatten=True)),
          ('llm_type', llm.llm_type),
          ('backend', llm.__llm_backend__),
          ('llm_tag', llm.tag),
        ),
        'has_adapters': llm.has_adapters,
        'prompt_template': prompt_template,
        'system_message': system_message,
      }
    ),
  )(
    _registry[llm.__llm_backend__],
    name=llm.runner_name,
    models=models,
    scheduling_strategy=CascadingResourceStrategy,
    runnable_init_params={'llm': llm},
  )


@registry
class CTranslateRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True
@@ -33,40 +101,23 @@ class CTranslateRunnable(bentoml.Runnable):
  def __init__(self, llm):
    if not is_ctranslate_available():
      raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
    self.config = llm.config
    self.model = llm.model
    self.tokenizer = llm.tokenizer
    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    if adapter_name is not None:
      raise NotImplementedError('Adapter is not supported with CTranslate.')

    stop_ = set()
    if isinstance(stop, str) and stop != '':
      stop_.add(stop)
    elif isinstance(stop, t.Iterable):
      stop_.update(stop)

    config = self.config.model_construct_env(stop=list(stop_), **attrs)
    sampling_params = dict(
      max_length=config['max_new_tokens'],
      min_length=config['min_length'],
      sampling_topk=config['top_k'],
      sampling_topp=config['top_p'],
      sampling_temperature=config['temperature'],
      return_log_prob=config['logprobs'] > 0,
      repetition_penalty=config['repetition_penalty'],
      no_repeat_ngram_size=config['no_repeat_ngram_size'],
      end_token=config['stop'],
    )
    cumulative_logprob = 0.0
    output_token_ids = list(prompt_token_ids)
    input_len = len(prompt_token_ids)
    async for request_output in self.model.async_generate_tokens(
      self.tokenizer.convert_ids_to_tokens(prompt_token_ids), **sampling_params
    ):
      cumulative_logprob += request_output.log_prob if config['logprobs'] else 0.0
    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
    cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
    tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)

    async for request_output in self.model.async_generate_tokens(tokens, **sampling_params):
      if config['logprobs']:
        cumulative_logprob += request_output.log_prob
      output_token_ids.append(request_output.token_id)
      text = self.tokenizer.decode(
        output_token_ids[input_len:],
@@ -92,34 +143,34 @@ class CTranslateRunnable(bentoml.Runnable):
      ).model_dump_json()


@registry
class vLLMRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    try:
      import vllm
    except ImportError:
      raise OpenLLMException('vLLM is not installed. Please install it via `pip install "openllm[vllm]"`.') from None
    self.config = llm.config
    if not is_vllm_available():
      raise OpenLLMException('vLLM is not installed. Please install it via `pip install "openllm[vllm]"`.')
    import vllm

    self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
    num_gpus, dev = 1, openllm.utils.device_count()
    if dev >= 2:
      num_gpus = min(dev // 2 * 2, dev)
    quantization = None
    if llm.quantise and llm.quantise in {'awq', 'squeezellm'}:
      quantization = llm.quantise

    try:
      self.model = vllm.AsyncLLMEngine.from_engine_args(
        vllm.AsyncEngineArgs(
          tokenizer_mode='auto',
          tensor_parallel_size=num_gpus,
          worker_use_ray=False,
          engine_use_ray=False,
          model=llm.bentomodel.path,
          tokenizer=llm.bentomodel.path,
          trust_remote_code=llm.trust_remote_code,
          tokenizer_mode='auto',
          tensor_parallel_size=num_gpus,
          dtype=llm._torch_dtype,
          quantization=quantization,
          worker_use_ray=False,
          engine_use_ray=False,
          max_model_len=llm._max_model_len,
          quantization=llm.quantise if llm.quantise and llm.quantise in {'awq', 'squeezellm'} else None,
        )
      )
    except Exception as err:
@@ -128,21 +179,13 @@ class vLLMRunnable(bentoml.Runnable):

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    if adapter_name is not None:
      raise NotImplementedError('Adapter is not supported with vLLM.')
    stop_ = set()
    if isinstance(stop, str) and stop != '':
      stop_.add(stop)
    elif isinstance(stop, t.Iterable):
      stop_.update(stop)

    temperature = attrs.pop('temperature', self.config['temperature'])
    top_p = attrs.pop('top_p', self.config['top_p'])
    if temperature <= 1e-5:
      top_p = 1.0
    sampling_params = self.config.model_construct_env(
      stop=list(stop_), temperature=temperature, top_p=top_p, **attrs
    ).to_sampling_config()
    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)

    async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
      # XXX: Need to write a hook for serialisation None correctly
@@ -151,32 +194,31 @@ class vLLMRunnable(bentoml.Runnable):
      yield GenerationOutput.from_vllm(request_output).model_dump_json()


@registry(alias='pt')
class PyTorchRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    self.model = llm.model
    self.tokenizer = llm.tokenizer
    self.config = llm.config
    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
    self.is_encoder_decoder = llm.model.config.is_encoder_decoder
    if hasattr(llm.model, 'device'):
      self.device = llm.model.device
    else:
      self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.is_encoder_decoder = llm.model.config.is_encoder_decoder

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    if adapter_name is not None:
      self.model.set_adapter(adapter_name)

    stop_ = set()
    if isinstance(stop, str) and stop != '':
      stop_.add(stop)
    elif isinstance(stop, t.Iterable):
      stop_.update(stop)
    async for generation_output in self.forward(
      prompt_token_ids=prompt_token_ids, request_id=request_id, stop=list(stop_), **attrs
    ):

    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop_), **attrs):
      yield generation_output.model_dump_json()

  async def forward(self, prompt_token_ids, request_id, stop, **attrs):

@@ -15,10 +15,13 @@ from typing import (
  final,
)

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from bentoml import Model, Strategy, Tag
from bentoml._internal.runner.runner_handle import RunnerHandle
from openllm_core import LLMConfig
from openllm_core._typing_compat import LiteralBackend, M, T, overload
from openllm_core._typing_compat import LiteralBackend, M, T

from ._llm import LLM

@@ -27,24 +30,21 @@ try:
except ImportError:
  AsyncLLMEngine = Any

try:
  from transformers import PreTrainedModel
except ImportError:
  PreTrainedModel = Any

try:
  from ctranslate2 import Generator, Translator
except ImportError:
  Translator = Any
  Generator = Any
  Translator = Generator = Any

Mo = TypeVar('Mo')
To = TypeVar('To')

class _Runnable(Protocol[Mo]):
class _Runnable(Protocol[Mo, To]):
  SUPPORTED_RESOURCES: Tuple[Literal['nvidia.com/gpu', 'amd.com/gpu', 'cpu'], ...] = ...
  SUPPORTS_CPU_MULTI_THREADING: bool = ...
  llm: LLM[Mo, To] = ...
  config: LLMConfig = ...
  model: Mo = ...
  tokenizer: To = ...
  def __init__(self, llm: LLM[Mo, T]) -> None: ...
  async def generate_iterator(
    self,
@@ -61,42 +61,25 @@ Ret = TypeVar('Ret')
class RunnerMethod(Generic[In, Ret]): ...

@final
class vLLMRunnable(_Runnable[AsyncLLMEngine]): ...
class vLLMRunnable(_Runnable[AsyncLLMEngine, PreTrainedTokenizer]): ...

@final
class CTranslateRunnable(_Runnable[Union[Translator, Generator]]):
  tokenizer: Any
class CTranslateRunnable(_Runnable[Union[Translator, Generator], PreTrainedTokenizer]): ...

@final
class PyTorchRunnable(_Runnable[PreTrainedModel]):
  tokenizer: Any
  async def forward(
    self,
    prompt_token_ids: List[int],
    request_id: str,
    stop: Iterable[str],
    adapter_name: Optional[str] = ...,
    **attrs: Any,
  ) -> AsyncGenerator[str, None]: ...
class PyTorchRunnable(_Runnable[PreTrainedModel, PreTrainedTokenizer]):
  is_encoder_decoder: bool = ...
  device: torch.device = ...

@overload
def runnable(llm: LLM[M, T], backend: Literal['vllm']) -> Type[vLLMRunnable]: ...
@overload
def runnable(llm: LLM[M, T], backend: Literal['pt']) -> Type[PyTorchRunnable]: ...
@overload
def runnable(llm: LLM[M, T], backend: Literal['ctranslate']) -> Type[CTranslateRunnable]: ...
@overload
def runnable(
  llm: LLM[M, T], backend: Optional[str] = ...
) -> Type[Union[vLLMRunnable, PyTorchRunnable, CTranslateRunnable]]: ...
def runner(llm: LLM[M, T]) -> Runner[M, T]: ...

class Runner(Protocol[Mo, T]):
class Runner(Protocol[Mo, To]):
  __doc__: str = ...
  __module__: str = ...
  llm_type: str = ...
  llm_tag: Tag = ...
  identifying_params: Dict[str, Any] = ...
  llm: LLM[Mo, T] = ...
  llm: LLM[Mo, To] = ...
  config: LLMConfig = ...
  backend: LiteralBackend = ...
  has_adapters: bool = ...
@@ -115,7 +98,7 @@ class Runner(Protocol[Mo, T]):

  def __init__(
    self,
    runnable_class: Type[_Runnable[Mo]],
    runnable_class: Type[_Runnable[Mo, To]],
    *,
    runnable_init_params: Optional[Dict[str, Any]] = ...,
    name: Optional[str] = ...,
@@ -130,7 +113,7 @@ class Runner(Protocol[Mo, T]):
  name: str = ...
  models: List[Model] = ...
  resource_config: Dict[str, Any]
  runnable_class: Type[_Runnable[Mo]]
  runnable_class: Type[_Runnable[Mo, To]]
  embedded: bool
  runner_methods: List[RunnerMethod[Any, Any]]
  scheduling_strategy: Type[Strategy]
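The runner helper above replaces the old backend if/elif dispatch with a small decorator-based registry keyed by backend name. Below is a self-contained sketch of that pattern; the class and backend names (EchoRunnable, 'echo') are placeholders for illustration and are not part of this commit.

# Minimal, standalone sketch of the decorator-based registry used by _runners.runner.
_registry = {}

def registry(cls=None, *, alias=None):
  def decorator(_cls):
    # Strip the 'Runnable' suffix (8 characters) to derive the key, unless an alias is given.
    _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls
    return _cls
  return decorator if cls is None else decorator(cls)

@registry  # registered under 'echo'
class EchoRunnable:
  def generate(self, prompt):
    return prompt

@registry(alias='pt')  # explicit alias, mirroring PyTorchRunnable in the diff
class PyTorchLikeRunnable:
  def generate(self, prompt):
    return prompt.upper()

assert _registry['echo'] is EchoRunnable
assert _registry['pt'] is PyTorchLikeRunnable

The alias parameter covers cases where the key derived from the class name does not match the backend literal, which is why the diff registers PyTorchRunnable with @registry(alias='pt').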