chore: cleanup loader (#729)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2023-11-22 21:51:51 -05:00
Committed by: GitHub
Parent: 5442d9cd10
Commit: 52a44b1bfa
8 changed files with 125 additions and 264 deletions

View File

@@ -83,8 +83,7 @@ class CTranslateRunnable(bentoml.Runnable):
   SUPPORTS_CPU_MULTI_THREADING = True
   def __init__(self, llm):
-    if not is_ctranslate_available():
-      raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
+    if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
     self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
   @bentoml.Runnable.method(batchable=False)
@@ -127,8 +126,7 @@ class vLLMRunnable(bentoml.Runnable):
   SUPPORTS_CPU_MULTI_THREADING = True
   def __init__(self, llm):
-    if not is_vllm_available():
-      raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
+    if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
     import vllm
     self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
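Both hunks above collapse the optional-backend guard onto a single line without changing behaviour. A standalone sketch of the same guard pattern, using importlib and a hypothetical MissingBackendError in place of openllm.exceptions.OpenLLMException:

import importlib.util


class MissingBackendError(RuntimeError):
  """Hypothetical stand-in for openllm.exceptions.OpenLLMException."""


def require_backend(module: str, extra: str) -> None:
  # Fail fast with an actionable install hint when the optional backend is absent.
  if importlib.util.find_spec(module) is None:
    raise MissingBackendError(f'{module} is not installed. Do `pip install "openllm[{extra}]"`.')


try:
  require_backend('vllm', 'vllm')  # raises unless vLLM is importable
except MissingBackendError as err:
  print(err)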
@@ -180,13 +178,8 @@ class PyTorchRunnable(bentoml.Runnable):
   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    if adapter_name is not None:
-      self.model.set_adapter(adapter_name)
-    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
-      yield generation_output.model_dump_json()
-  async def forward(self, prompt_token_ids, request_id, stop, **attrs):
-    from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
+    from ._generation import get_context_length, prepare_logits_processor
+    if adapter_name is not None: self.model.set_adapter(adapter_name)
     max_new_tokens = attrs.pop('max_new_tokens', 256)
     context_length = attrs.pop('context_length', None)
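This hunk removes the separate forward coroutine, leaving generate_iterator as the single async generator that serialises each chunk before yielding it. A minimal, hypothetical sketch of that shape (Chunk and the fake decode step are stand-ins, not OpenLLM APIs), assuming pydantic v2 for model_dump_json:

import asyncio
from pydantic import BaseModel


class Chunk(BaseModel):
  index: int
  text: str


async def generate_iterator(prompt: str):
  # Single generator: decode a step, wrap it, serialise it, yield it -- no inner `forward` to delegate to.
  for i, word in enumerate(prompt.split()):
    await asyncio.sleep(0)  # stand-in for an awaited model step
    yield Chunk(index=i, text=word).model_dump_json()


async def main() -> None:
  async for payload in generate_iterator('streaming three tokens'):
    print(payload)  # e.g. {"index":0,"text":"streaming"}


asyncio.run(main())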
@@ -224,6 +217,8 @@ class PyTorchRunnable(bentoml.Runnable):
     finish_reason = None
     prompt_logprobs = []
+    prompt_token_indices = []
     stopped = False
+    sample_logprobs: SampleLogprobs = [None]  # The first token has no logprobs
     for i in range(config['max_new_tokens']):
       if i == 0: # prefill
@@ -247,10 +242,9 @@ class PyTorchRunnable(bentoml.Runnable):
         )
         logits = out.logits
         past_key_values = out.past_key_values
       if logits_processor:
         if config['repetition_penalty'] > 1.0:
-          tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device)
+          tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
         else:
           tmp_output_ids = None
         last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
@@ -272,11 +266,11 @@ class PyTorchRunnable(bentoml.Runnable):
       token = tokens[0]
       output_token_ids.append(token)
-      # NOTE: We can't use last_token_logits since logprobs is based on raw logits
-      logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
-      sample_logprobs: SampleLogprobs = []
-      token_logprobs = logprobs[token].item()
-      cumulative_logprob += token_logprobs
+      if config['logprobs']:
+        # NOTE: We can't use last_token_logits since logprobs is based on raw logits
+        logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
+        token_logprobs = logprobs[token].item()
+        cumulative_logprob += token_logprobs
       if config['prompt_logprobs']:
         for token_id in prompt_token_ids:
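With this change, per-token log-probabilities are only computed when config['logprobs'] is set, and always from the raw logits rather than the processed last_token_logits. A small sketch of that computation with dummy tensors (shapes and the want_logprobs flag are illustrative only):

import torch

logits = torch.randn(1, 4, 32)  # [batch, seq_len, vocab] -- dummy values
token = int(torch.argmax(logits[0, -1, :]))  # pretend this token was just sampled
want_logprobs = True
cumulative_logprob = 0.0

if want_logprobs:
  # Normalise the raw logits of the last position; penalised/warped logits would skew these values.
  logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
  token_logprob = logprobs[token].item()
  cumulative_logprob += token_logprob
  print({token: token_logprob})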
@@ -296,40 +290,32 @@ class PyTorchRunnable(bentoml.Runnable):
           clean_up_tokenization_spaces=True,
         )
-        partially_stopped = False
         if len(stop) > 0:
           for it in stop:
             pos = text.rfind(it, rfind_start)
             if pos != -1:
               text, stopped = text[:pos], True
               break
-            else:
-              partially_stopped = is_partial_stop(text, it)
-              if partially_stopped:
-                break
-        if config['logprobs']:
-          sample_logprobs.append({token: token_logprobs})
+        if config['logprobs']: sample_logprobs.append({token: token_logprobs})
-        if not partially_stopped:
-          # TODO: calculate prompt_logprobs
-          yield GenerationOutput(
-            prompt='',
-            finished=False,
-            outputs=[
-              CompletionChunk(
-                index=0,
-                text=text,
-                token_ids=tmp_output_ids,
-                cumulative_logprob=cumulative_logprob,
-                logprobs=sample_logprobs if config['logprobs'] else None,
-                finish_reason=None,
-              )
-            ],
-            prompt_token_ids=prompt_token_ids,
-            prompt_logprobs=prompt_logprobs,
-            request_id=request_id,
-          )
+        yield GenerationOutput(
+          prompt='',
+          finished=False,
+          outputs=[
+            CompletionChunk(
+              index=0,
+              text=text,
+              token_ids=tmp_output_ids,
+              cumulative_logprob=cumulative_logprob,
+              logprobs=sample_logprobs if config['logprobs'] else None,
+              finish_reason=None,
+            )
+          ],
+          prompt_token_ids=prompt_token_ids,
+          prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
+          request_id=request_id,
+        ).model_dump_json()
       if stopped:
         break
     else:
@@ -343,16 +329,16 @@ class PyTorchRunnable(bentoml.Runnable):
         CompletionChunk(
           index=0,
           text=text,
-          token_ids=output_token_ids[input_len:],
+          token_ids=output_token_ids,
          cumulative_logprob=cumulative_logprob,
           logprobs=sample_logprobs if config['logprobs'] else None,
           finish_reason=finish_reason,
         )
       ],
       prompt_token_ids=prompt_token_ids,
-      prompt_logprobs=prompt_logprobs,
+      prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
       request_id=request_id,
-    )
+    ).model_dump_json()
     # Clean
     del past_key_values, out
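In the streaming loop above, decoded text is trimmed at the first stop string found via rfind; the partial-stop branch is dropped along with the is_partial_stop import. A self-contained sketch of that trimming step (the sample text is made up; rfind_start mirrors the variable in the diff):

text = 'The capital of France is Paris.</s> extra tokens'
stop = ['</s>', '<|endoftext|>']
rfind_start = 0  # in the real loop this skips text that was already emitted
stopped = False

for it in stop:
  pos = text.rfind(it, rfind_start)
  if pos != -1:
    # Cut everything from the stop string onwards and mark generation as finished.
    text, stopped = text[:pos], True
    break

print(stopped, repr(text))  # True 'The capital of France is Paris.'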

View File

@@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm
 from huggingface_hub import snapshot_download
 from openllm_core.exceptions import OpenLLMException
-from openllm_core.utils import first_not_none, is_autogptq_available
+from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available
 from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
 from .weights import HfIgnore
@@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs):
     if llm.config['model_type'] != 'causal_lm':
       raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
     # TODO: investigate load with flash attention
-    model = auto_class.from_pretrained(
-      llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
-    )
+    try:
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
+      )
+    except Exception as err:
+      logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
+      )
   else:
-    model = auto_class.from_pretrained(
-      llm.bentomodel.path,
-      *decls,
-      config=config,
-      trust_remote_code=llm.trust_remote_code,
-      device_map=device_map,
-      **attrs,
-    )
+    try:
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path,
+        *decls,
+        config=config,
+        trust_remote_code=llm.trust_remote_code,
+        device_map=device_map,
+        use_flash_attention_2=is_flash_attn_2_available(),
+        **attrs,
+      )
+    except Exception as err:
+      logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path,
+        *decls,
+        config=config,
+        trust_remote_code=llm.trust_remote_code,
+        device_map=device_map,
+        **attrs,
+      )
   check_unintialised_params(model)
   return model
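Both branches of load_model now opportunistically request FlashAttention-2 and fall back to a plain load if that raises. A condensed sketch of the same try/except pattern against the public transformers API (the checkpoint id and the importlib-based availability probe are placeholders, not OpenLLM's helpers):

import importlib.util
import logging

from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
model_id = 'facebook/opt-125m'  # placeholder checkpoint


def flash_attn_2_installed() -> bool:
  # Rough stand-in for openllm_core.utils.is_flash_attn_2_available.
  return importlib.util.find_spec('flash_attn') is not None


try:
  # First attempt: ask for FlashAttention-2 kernels when the package is present.
  model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=flash_attn_2_installed())
except Exception as err:
  logger.debug("Failed to load model with 'use_flash_attention_2':\n%s", err)
  # Fallback: plain load without the flag.
  model = AutoModelForCausalLM.from_pretrained(model_id)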

View File

@@ -35,6 +35,7 @@ from openllm_core.utils import (
   is_bentoml_available as is_bentoml_available,
   is_bitsandbytes_available as is_bitsandbytes_available,
   is_ctranslate_available as is_ctranslate_available,
+  is_flash_attn_2_available as is_flash_attn_2_available,
   is_grpc_available as is_grpc_available,
   is_jupyter_available as is_jupyter_available,
   is_jupytext_available as is_jupytext_available,
@@ -55,7 +56,7 @@ from openllm_core.utils import (
 )
 from openllm_core.utils.serde import converter as converter
-from .._llm import LLM
+from ._llm import LLM
 def available_devices() -> Tuple[str, ...]: ...
 def device_count() -> int: ...
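The stub re-exports is_flash_attn_2_available so it can be imported from openllm.utils; its implementation is not part of this diff. Purely as an assumption, such an availability probe is commonly written as a cached find_spec lookup:

import functools
import importlib.util


@functools.lru_cache(maxsize=None)
def is_flash_attn_2_available() -> bool:
  # Hypothetical implementation: report whether the flash-attn package can be imported.
  return importlib.util.find_spec('flash_attn') is not None


print(is_flash_attn_2_available())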