mirror of https://github.com/bentoml/OpenLLM.git
chore: cleanup loader (#729)
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
@@ -83,8 +83,7 @@ class CTranslateRunnable(bentoml.Runnable):
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self, llm):
if not is_ctranslate_available():
raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer

@bentoml.Runnable.method(batchable=False)
@@ -127,8 +126,7 @@ class vLLMRunnable(bentoml.Runnable):
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self, llm):
if not is_vllm_available():
raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
import vllm

self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
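Both hunks above collapse the optional-dependency guard onto a single line. A minimal sketch of that pattern, assuming a generic find_spec-based check in place of openllm_core's is_ctranslate_available()/is_vllm_available() helpers:

import importlib.util

def _is_available(pkg: str) -> bool:
  # find_spec returns None when the package cannot be imported.
  return importlib.util.find_spec(pkg) is not None

class _Runnable:
  def __init__(self):
    # Fail fast with an actionable install hint, mirroring the cleaned-up loader above.
    if not _is_available('vllm'): raise RuntimeError('vLLM is not installed. Do `pip install "openllm[vllm]"`.')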
@@ -180,13 +178,8 @@ class PyTorchRunnable(bentoml.Runnable):

@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
self.model.set_adapter(adapter_name)
async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
yield generation_output.model_dump_json()

async def forward(self, prompt_token_ids, request_id, stop, **attrs):
from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
from ._generation import get_context_length, prepare_logits_processor
if adapter_name is not None: self.model.set_adapter(adapter_name)

max_new_tokens = attrs.pop('max_new_tokens', 256)
context_length = attrs.pop('context_length', None)
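With forward folded into generate_iterator, a single async generator now yields each decoding step as a JSON string. A rough sketch of that streaming shape, with a plain dict and json.dumps standing in for the pydantic GenerationOutput.model_dump_json() call used in the diff:

import asyncio, json

async def generate_iterator(prompt_token_ids, request_id, stop=None, adapter_name=None):
  output_token_ids = list(prompt_token_ids)
  for step in range(3):  # placeholder for the real decoding loop
    output_token_ids.append(step)
    yield json.dumps({'request_id': request_id, 'token_ids': output_token_ids, 'finished': False})
  yield json.dumps({'request_id': request_id, 'token_ids': output_token_ids, 'finished': True})

async def _demo():
  async for chunk in generate_iterator([1, 2, 3], 'req-0'):
    print(chunk)

asyncio.run(_demo())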
@@ -224,6 +217,8 @@ class PyTorchRunnable(bentoml.Runnable):
finish_reason = None
prompt_logprobs = []
prompt_token_indices = []
stopped = False
sample_logprobs: SampleLogprobs = [None] # The first token has no logprobs

for i in range(config['max_new_tokens']):
if i == 0: # prefill
@@ -247,10 +242,9 @@ class PyTorchRunnable(bentoml.Runnable):
)
logits = out.logits
past_key_values = out.past_key_values

if logits_processor:
if config['repetition_penalty'] > 1.0:
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device)
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
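Where the hunk above applies logits_processor to the last position's logits, transformers' stock processors give a self-contained picture of the same call shape; prepare_logits_processor itself is not shown in this diff, so the processor list below is an assumption:

import torch
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper

logits_processor = LogitsProcessorList([
  RepetitionPenaltyLogitsProcessor(penalty=1.2),  # only built when repetition_penalty > 1.0
  TemperatureLogitsWarper(temperature=0.8),
])
output_token_ids = [5, 7, 7, 9]                       # tokens generated so far
logits = torch.randn(1, len(output_token_ids), 32)    # (batch, seq, vocab) from the forward pass
tmp_output_ids = torch.as_tensor([output_token_ids])  # ids the repetition penalty inspects
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
print(last_token_logits.shape)  # torch.Size([32])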
@@ -272,11 +266,11 @@ class PyTorchRunnable(bentoml.Runnable):

token = tokens[0]
output_token_ids.append(token)
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
sample_logprobs: SampleLogprobs = []
token_logprobs = logprobs[token].item()
cumulative_logprob += token_logprobs
if config['logprobs']:
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
token_logprobs = logprobs[token].item()
cumulative_logprob += token_logprobs

if config['prompt_logprobs']:
for token_id in prompt_token_ids:
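The logprob bookkeeping in the hunk above reduces to: take log_softmax over the raw logits at the last position, record the chosen token's logprob, and add it to the running cumulative_logprob. A small self-contained illustration (shapes and the greedy token choice are assumptions):

import torch

logits = torch.randn(1, 4, 32)  # (batch, seq_len, vocab) from a forward pass
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
token = int(torch.argmax(logprobs))          # greedy pick, for the sketch only
token_logprobs = logprobs[token].item()
cumulative_logprob = 0.0
cumulative_logprob += token_logprobs
sample_logprobs = [{token: token_logprobs}]  # per-step {token_id: logprob} entries
print(token, token_logprobs, cumulative_logprob)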
@@ -296,40 +290,32 @@ class PyTorchRunnable(bentoml.Runnable):
clean_up_tokenization_spaces=True,
)

partially_stopped = False
if len(stop) > 0:
for it in stop:
pos = text.rfind(it, rfind_start)
if pos != -1:
text, stopped = text[:pos], True
break
else:
partially_stopped = is_partial_stop(text, it)
if partially_stopped:
break

if config['logprobs']:
sample_logprobs.append({token: token_logprobs})
if config['logprobs']: sample_logprobs.append({token: token_logprobs})

if not partially_stopped:
# TODO: calculate prompt_logprobs
yield GenerationOutput(
prompt='',
finished=False,
outputs=[
CompletionChunk(
index=0,
text=text,
token_ids=tmp_output_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
request_id=request_id,
)
yield GenerationOutput(
prompt='',
finished=False,
outputs=[
CompletionChunk(
index=0,
text=text,
token_ids=tmp_output_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
request_id=request_id,
).model_dump_json()
if stopped:
break
else:
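The stop-string handling kept in the hunk above boils down to an rfind over the decoded text for each stop sequence, truncating at the first hit. A minimal sketch, assuming rfind_start marks where not-yet-scanned text begins:

def apply_stop_strings(text, stop, rfind_start=0):
  stopped = False
  for it in stop:
    pos = text.rfind(it, rfind_start)
    if pos != -1:
      text, stopped = text[:pos], True
      break
  return text, stopped

print(apply_stop_strings('Hello world</s> trailing', ['</s>']))  # ('Hello world', True)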
@@ -343,16 +329,16 @@ class PyTorchRunnable(bentoml.Runnable):
CompletionChunk(
index=0,
text=text,
token_ids=output_token_ids[input_len:],
token_ids=output_token_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=finish_reason,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
request_id=request_id,
)
).model_dump_json()

# Clean
del past_key_values, out
@@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm

from huggingface_hub import snapshot_download
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_autogptq_available
from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available

from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
from .weights import HfIgnore
@@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs):
if llm.config['model_type'] != 'causal_lm':
raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")

# TODO: investigate load with flash attention
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
try:
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
)
except Exception as err:
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
else:
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)
try:
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
use_flash_attention_2=is_flash_attn_2_available(),
**attrs,
)
except Exception as err:
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)

check_unintialised_params(model)
return model
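The load_model change above wraps both loading paths in the same try/except: attempt from_pretrained with use_flash_attention_2 enabled, then fall back to a plain load if that raises (flash-attn missing, unsupported architecture, CPU-only host, and so on). A condensed sketch of the pattern; the checkpoint id and the availability helper below are placeholders, not OpenLLM's actual values:

import importlib.util, logging
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
model_id = 'gpt2'  # placeholder checkpoint

def is_flash_attn_2_available():  # stand-in for openllm_core's helper
  return importlib.util.find_spec('flash_attn') is not None

try:
  model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=is_flash_attn_2_available())
except Exception as err:
  logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
  model = AutoModelForCausalLM.from_pretrained(model_id)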
@@ -35,6 +35,7 @@ from openllm_core.utils import (
is_bentoml_available as is_bentoml_available,
is_bitsandbytes_available as is_bitsandbytes_available,
is_ctranslate_available as is_ctranslate_available,
is_flash_attn_2_available as is_flash_attn_2_available,
is_grpc_available as is_grpc_available,
is_jupyter_available as is_jupyter_available,
is_jupytext_available as is_jupytext_available,
@@ -55,7 +56,7 @@ from openllm_core.utils import (
)
from openllm_core.utils.serde import converter as converter

from .._llm import LLM
from ._llm import LLM

def available_devices() -> Tuple[str, ...]: ...
def device_count() -> int: ...