mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-05 07:36:15 -05:00
chore(style): add one blank line
to conform with Google style Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,7 @@ if t.TYPE_CHECKING: import torch
|
||||
|
||||
_GENERIC_EMBEDDING_ID = 'sentence-transformers/all-MiniLM-L6-v2'
|
||||
_BENTOMODEL_ID = 'sentence-transformers--all-MiniLM-L6-v2'
|
||||
|
||||
def get_or_download(ids: str = _BENTOMODEL_ID) -> bentoml.Model:
|
||||
try:
|
||||
return bentoml.transformers.get(ids)
|
||||
@@ -36,6 +37,7 @@ def get_or_download(ids: str = _BENTOMODEL_ID) -> bentoml.Model:
|
||||
_GENERIC_EMBEDDING_ID, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=['*.safetensors', '*.h5', '*.ot', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt']
|
||||
)
|
||||
return bentomodel
|
||||
|
||||
class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
@@ -67,4 +69,5 @@ class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
|
||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
__all__ = ['GenericEmbeddingRunnable']
|
||||
|
||||
@@ -8,6 +8,7 @@ if t.TYPE_CHECKING: import torch, openllm
|
||||
# reexport from transformers
|
||||
LogitsProcessorList = transformers.LogitsProcessorList
|
||||
StoppingCriteriaList = transformers.StoppingCriteriaList
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
|
||||
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
|
||||
@@ -15,9 +16,11 @@ class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **_: t.Any) -> bool:
|
||||
return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
|
||||
return input_ids[0][-1] in {50278, 50279, 50277, 1, 0}
|
||||
|
||||
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
@@ -26,16 +29,20 @@ def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsPr
|
||||
if 1e-8 <= generation_config['top_p']: logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
|
||||
if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
return logits_processor
|
||||
|
||||
# NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used.
|
||||
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length']
|
||||
|
||||
def get_context_length(config: transformers.PretrainedConfig) -> int:
|
||||
rope_scaling = getattr(config, 'rope_scaling', None)
|
||||
rope_scaling_factor = config.rope_scaling['factor'] if rope_scaling else 1.0
|
||||
for key in SEQLEN_KEYS:
|
||||
if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key))
|
||||
return 2048
|
||||
|
||||
def is_sentence_complete(output: str) -> bool:
|
||||
return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”'))
|
||||
|
||||
def is_partial_stop(output: str, stop_str: str) -> bool:
|
||||
'''Check whether the output contains a partial stop str.'''
|
||||
for i in range(0, min(len(output), len(stop_str))):
|
||||
|
||||
@@ -49,15 +49,19 @@ else:
|
||||
ResolvedAdaptersMapping = t.Dict[AdapterType, t.Dict[str, t.Tuple['peft.PeftConfig', str]]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelSignatureDict(t.TypedDict, total=False):
|
||||
batchable: bool
|
||||
batch_dim: t.Union[t.Tuple[int, int], int]
|
||||
input_spec: NotRequired[t.Union[t.Any, t.Tuple[t.Any]]]
|
||||
output_spec: NotRequired[t.Any]
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else re.sub('[^a-zA-Z0-9]+', '-', name)
|
||||
|
||||
# the below is similar to peft.utils.other.CONFIG_NAME
|
||||
PEFT_CONFIG_NAME = 'adapter_config.json'
|
||||
|
||||
def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapping:
|
||||
'''Resolve the type of the PeftConfig given the adapter_map.
|
||||
|
||||
@@ -88,7 +92,9 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapp
|
||||
if _peft_type not in resolved: resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += (_AdaptersTuple((path_or_adapter_id, resolve_name, resolved_config)),)
|
||||
return resolved
|
||||
|
||||
_reserved_namespace = {'config_class', 'model', 'tokenizer', 'import_kwargs'}
|
||||
|
||||
class LLMInterface(abc.ABC, t.Generic[M, T]):
|
||||
'''This defines the loose contract for all openllm.LLM implementations.'''
|
||||
@property
|
||||
@@ -241,23 +247,31 @@ class LLMInterface(abc.ABC, t.Generic[M, T]):
|
||||
**attrs: t.Any
|
||||
) -> None:
|
||||
'''Generated __attrs_init__ for openllm.LLM.'''
|
||||
|
||||
_R = t.TypeVar('_R', covariant=True)
|
||||
|
||||
class _import_model_wrapper(t.Generic[_R, M, T], t.Protocol):
|
||||
def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R:
|
||||
...
|
||||
|
||||
class _load_model_wrapper(t.Generic[M, T], t.Protocol):
|
||||
def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
...
|
||||
|
||||
class _load_tokenizer_wrapper(t.Generic[M, T], t.Protocol):
|
||||
def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T:
|
||||
...
|
||||
|
||||
class _llm_post_init_wrapper(t.Generic[M, T], t.Protocol):
|
||||
def __call__(self, llm: LLM[M, T]) -> T:
|
||||
...
|
||||
|
||||
class _save_pretrained_wrapper(t.Generic[M, T], t.Protocol):
|
||||
def __call__(self, llm: LLM[M, T], save_directory: str | pathlib.Path, **attrs: t.Any) -> None:
|
||||
...
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
# NOTE: the following wrapper are a light meta ops for wrapping default params to internal methods implementation.
|
||||
def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]:
|
||||
@functools.wraps(f)
|
||||
@@ -269,11 +283,14 @@ def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Ca
|
||||
return f(self, *decls, trust_remote_code=trust_remote_code, **attrs)
|
||||
|
||||
return wrapper
|
||||
|
||||
_DEFAULT_TOKENIZER = 'hf-internal-testing/llama-tokenizer'
|
||||
|
||||
def get_engine_args(llm: LLM[M, T], tokenizer: str = _DEFAULT_TOKENIZER) -> vllm.EngineArgs:
|
||||
return vllm.EngineArgs(
|
||||
model=llm._bentomodel.path, tokenizer=tokenizer, tokenizer_mode='auto', tensor_parallel_size=1 if device_count() < 2 else device_count(), dtype='auto', worker_use_ray=False
|
||||
)
|
||||
|
||||
def _wrapped_load_model(f: _load_model_wrapper[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.LLMEngine]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
|
||||
@@ -289,12 +306,14 @@ def _wrapped_load_model(f: _load_model_wrapper[M, T]) -> t.Callable[[LLM[M, T]],
|
||||
return f(self, *(*model_decls, *decls), **{**model_attrs, **attrs})
|
||||
|
||||
return wrapper
|
||||
|
||||
def _wrapped_load_tokenizer(f: _load_tokenizer_wrapper[M, T]) -> t.Callable[[LLM[M, T]], T]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T:
|
||||
return f(self, **{**self.llm_parameters[-1], **tokenizer_attrs})
|
||||
|
||||
return wrapper
|
||||
|
||||
def _wrapped_llm_post_init(f: _llm_post_init_wrapper[M, T]) -> t.Callable[[LLM[M, T]], None]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(self: LLM[M, T]) -> None:
|
||||
@@ -302,6 +321,7 @@ def _wrapped_llm_post_init(f: _llm_post_init_wrapper[M, T]) -> t.Callable[[LLM[M
|
||||
f(self)
|
||||
|
||||
return wrapper
|
||||
|
||||
def _wrapped_save_pretrained(f: _save_pretrained_wrapper[M, T]) -> t.Callable[[LLM[M, T], str | pathlib.Path], None]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(self: LLM[M, T], save_directory: str | pathlib.Path, **attrs: t.Any) -> None:
|
||||
@@ -312,6 +332,7 @@ def _wrapped_save_pretrained(f: _save_pretrained_wrapper[M, T]) -> t.Callable[[L
|
||||
f(self, save_directory, **attrs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def _update_docstring(cls: LLM[M, T], fn: str) -> AnyCallable:
|
||||
# update docstring for given entrypoint
|
||||
original_fn = getattr(cls, fn, getattr(LLMInterface, fn))
|
||||
@@ -323,6 +344,7 @@ def _update_docstring(cls: LLM[M, T], fn: str) -> AnyCallable:
|
||||
'''
|
||||
setattr(cls, fn, original_fn)
|
||||
return original_fn
|
||||
|
||||
def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], None]:
|
||||
attributes = {
|
||||
'import_model': _wrapped_import_model,
|
||||
@@ -361,8 +383,10 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]
|
||||
lines.extend([_setattr_class(key, f"cls.{fn} is not _cached_LLMInterface_get('{fn}')"), f"__gen_docstring(cls, '{fn}')",])
|
||||
anns[key] = interface_anns.get(key)
|
||||
return codegen.generate_function(cls, '__assign_llm_attr', lines, args=('cls', *args), globs=globs, annotations=anns)
|
||||
|
||||
def vllm_postprocess_generate(self: LLM['vllm.LLMEngine', T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str:
|
||||
return generation_result[0]['outputs'][0]['text']
|
||||
|
||||
def vllm_generate_iterator(
|
||||
self: LLM['vllm.LLMEngine', T], prompt: str, /, *, echo: bool = False, stop: str | t.Iterable[str] | None = None, stop_token_ids: list[int] | None = None, **attrs: t.Any
|
||||
) -> t.Iterator[dict[str, t.Any]]:
|
||||
@@ -387,6 +411,7 @@ def vllm_generate_iterator(
|
||||
else: text_outputs = [output.text for output in request_output.outputs]
|
||||
yield {'text': text_outputs, 'error_code': 0}
|
||||
if request_output.finished: break
|
||||
|
||||
def vllm_generate(self: LLM['vllm.LLMEngine', T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]:
|
||||
request_id: str = attrs.pop('request_id', None)
|
||||
if request_id is None: raise ValueError('request_id must not be None.')
|
||||
@@ -396,7 +421,9 @@ def vllm_generate(self: LLM['vllm.LLMEngine', T], prompt: str, **attrs: t.Any) -
|
||||
while self.model.has_unfinished_requests():
|
||||
outputs.extend([r for r in self.model.step() if r.finished])
|
||||
return [unmarshal_vllm_outputs(i) for i in outputs]
|
||||
|
||||
_AdaptersTuple: type[AdaptersTuple] = codegen.make_attr_tuple_class('AdaptersTuple', ['adapter_id', 'name', 'config'])
|
||||
|
||||
@attr.define(slots=True, repr=False, init=False)
|
||||
class LLM(LLMInterface[M, T], ReprMixin):
|
||||
if t.TYPE_CHECKING: __name__: str
|
||||
@@ -1140,6 +1167,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
del past_key_values, out
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# fmt: off
|
||||
@overload
|
||||
def Runner(model_name: str, *, model_id: str | None = None, model_version: str | None = ..., init_local: t.Literal[False, True] = ..., **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: ...
|
||||
|
||||
@@ -14,12 +14,15 @@ autogptq, torch, transformers = LazyLoader('autogptq', globals(), 'auto_gptq'),
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QuantiseMode = t.Literal['int8', 'int4', 'gptq']
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
|
||||
@@ -38,17 +38,20 @@ runners: list[AbstractRunner] = [runner]
|
||||
if not runner.supports_embeddings: runners.append(generic_embedding_runner)
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners)
|
||||
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
|
||||
|
||||
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, **{'adapter_name': qa_inputs.adapter_name, **config})
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
|
||||
@svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream'))
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
echo = input_dict.pop('echo', False)
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())
|
||||
|
||||
@svc.api(
|
||||
route='/v1/metadata',
|
||||
input=bentoml.io.Text(),
|
||||
@@ -72,6 +75,7 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
supports_embeddings=runner.supports_embeddings,
|
||||
supports_hf_agent=runner.supports_hf_agent
|
||||
)
|
||||
|
||||
@svc.api(
|
||||
route='/v1/embeddings',
|
||||
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),
|
||||
@@ -111,6 +115,7 @@ async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
embed_call: _EmbeddingMethod = runner.embeddings if runner.supports_embeddings else generic_embedding_runner.encode # type: ignore[type-arg,assignment,valid-type]
|
||||
responses = (await embed_call.async_run(phrases))[0]
|
||||
return openllm.EmbeddingsOutput(embeddings=responses['embeddings'], num_tokens=responses['num_tokens'])
|
||||
|
||||
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
@@ -127,11 +132,13 @@ if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
|
||||
hf_app = Starlette(debug=True, routes=[Route('/agent', hf_agent, methods=['POST'])])
|
||||
svc.mount_asgi_app(hf_app, path='/hf')
|
||||
|
||||
# general metadata app
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters['success'] is True: res['result'] = {k: v.to_dict() for k, v in runner.peft_adapters['result'].items()}
|
||||
res.update({'success': runner.peft_adapters['success'], 'error_msg': runner.peft_adapters['error_msg']})
|
||||
return JSONResponse(res, status_code=200)
|
||||
|
||||
adapters_app_v1 = Starlette(debug=True, routes=[Route('/adapters', list_adapter_v1, methods=['GET'])])
|
||||
svc.mount_asgi_app(adapters_app_v1, path='/v1')
|
||||
|
||||
@@ -30,6 +30,7 @@ if t.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
|
||||
def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
|
||||
'''Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.'''
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != 'true': return None
|
||||
@@ -48,6 +49,7 @@ def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'ope
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.')
|
||||
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions:
|
||||
packages = ['openllm', 'scipy'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
@@ -100,6 +102,7 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
]
|
||||
if all(i for i in built_wheels): wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=['https://download.pytorch.org/whl/cu118'])
|
||||
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
_: FS,
|
||||
@@ -137,8 +140,10 @@ def construct_docker_options(
|
||||
if _env['quantize_value'] is not None: env_dict[_env.quantize] = t.cast(str, _env['quantize_value'])
|
||||
env_dict[_env.runtime] = _env['runtime_value']
|
||||
return DockerOptions(base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}', env=env_dict, dockerfile_template=dockerfile_template)
|
||||
|
||||
OPENLLM_MODEL_NAME = '# openllm: model name'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: LiteralString = '__model_name__'
|
||||
|
||||
@@ -156,11 +161,15 @@ class ModelNameFormatter(string.Formatter):
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_id__'
|
||||
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_adapter_map__'
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
model_name = llm.config['model_name']
|
||||
@@ -174,6 +183,7 @@ def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | N
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if DEBUG: logger.info('Generated script:\n%s', script)
|
||||
llm_fs.writetext(llm.config['service_name'], script)
|
||||
|
||||
@inject
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
|
||||
@@ -40,17 +40,23 @@ _OWNER = 'bentoml'
|
||||
_REPO = 'openllm'
|
||||
|
||||
_module_location = openllm_core.utils.pkg.source_locations('openllm')
|
||||
|
||||
@functools.lru_cache
|
||||
@openllm_core.utils.apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
return _CONTAINER_REGISTRY[reg]
|
||||
|
||||
def _convert_version_from_string(s: str) -> VersionInfo:
|
||||
return VersionInfo.from_version_string(s)
|
||||
|
||||
def _commit_time_range(r: int = 5) -> str:
|
||||
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class('_RefTuple', ['git_hash', 'version', 'strategy'])
|
||||
|
||||
def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
# NOTE: all openllm container will have sha-<git_hash[:7]>
|
||||
# This will use docker to run skopeo to determine the correct latest tag that is available
|
||||
@@ -64,6 +70,7 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if '[skip ci]' not in it['commit']['message'])
|
||||
# now is the correct behaviour
|
||||
return orjson.loads(subprocess.check_output([docker_bin, 'run', '--rm', '-it', 'quay.io/skopeo/stable:latest', 'list-tags', 'docker://ghcr.io/bentoml/openllm']).decode().strip())['Tags'][-2]
|
||||
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
git_hash: str = attr.field()
|
||||
@@ -108,9 +115,11 @@ class RefResolver:
|
||||
if self.strategy == 'latest': return 'latest'
|
||||
elif self.strategy == 'nightly': return self.git_hash
|
||||
else: return repr(self.version)
|
||||
|
||||
@functools.lru_cache(maxsize=256)
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
|
||||
return RefResolver.from_strategy(strategy).tag
|
||||
|
||||
def build_container(
|
||||
registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
@@ -146,13 +155,16 @@ def build_container(
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}') from err
|
||||
return tags
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
|
||||
__all__ = ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'supported_registries': return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == 'CONTAINER_NAMES': return _CONTAINER_REGISTRY
|
||||
|
||||
@@ -31,10 +31,13 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(str(it.tag), help='Bento') for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})]
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
@@ -55,7 +58,9 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG: logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
return environ
|
||||
|
||||
_adapter_mapping_key = 'adapter_map'
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
@@ -69,6 +74,7 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
return None
|
||||
|
||||
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
command_attrs: DictStrAny = dict(
|
||||
@@ -212,6 +218,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
return config
|
||||
|
||||
return start_cmd
|
||||
|
||||
def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, **command_attrs: t.Any) -> click.Command:
|
||||
context_settings = command_attrs.pop('context_settings', {})
|
||||
context_settings.update({'ignore_unknown_options': True, 'allow_extra_args': True})
|
||||
@@ -224,6 +231,7 @@ def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, *
|
||||
return llm_config
|
||||
|
||||
return noop
|
||||
|
||||
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
|
||||
if adapter_map and not openllm.utils.is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantize and llm_config.default_implementation() == 'vllm':
|
||||
@@ -232,6 +240,7 @@ def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: Lite
|
||||
if requirements is not None and len(requirements) > 0:
|
||||
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
|
||||
if len(missing_requirements) > 0: termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
|
||||
|
||||
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
@@ -301,6 +310,7 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
return composed(fn)
|
||||
|
||||
return wrapper
|
||||
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None: return value
|
||||
if not isinstance(value, tuple): ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
@@ -308,10 +318,12 @@ def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tup
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
# NOTE: production is also deprecated
|
||||
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
|
||||
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
'''Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.'''
|
||||
from bentoml_cli.cli import cli
|
||||
@@ -339,7 +351,9 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
return group(f)
|
||||
|
||||
return decorator
|
||||
|
||||
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
'''General ``@click`` decorator with some sauce.
|
||||
|
||||
@@ -356,8 +370,10 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
|
||||
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
|
||||
cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
|
||||
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = 'pretty', **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
output = ['json', 'pretty', 'porcelain']
|
||||
|
||||
@@ -377,6 +393,7 @@ def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput
|
||||
shell_complete=complete_output_var,
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--fast/--no-fast',
|
||||
@@ -390,10 +407,13 @@ def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs)(f)
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
@@ -404,10 +424,13 @@ def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.E
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
@@ -433,6 +456,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.''',
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--workers-per-resource',
|
||||
@@ -458,6 +482,7 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
|
||||
),
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--bettertransformer',
|
||||
@@ -469,6 +494,7 @@ def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = Fal
|
||||
if not build else 'Set default environment variable whether to serve this model with FasterTransformer in build time.',
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
@@ -498,6 +524,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--container-registry',
|
||||
@@ -517,7 +544,9 @@ def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) ->
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
value = inflection.underscore(value)
|
||||
@@ -529,6 +558,7 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
|
||||
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value not in openllm.bundle.supported_registries: raise click.BadParameter(f'Value must be one of {openllm.bundle.supported_registries}', ctx, param)
|
||||
|
||||
@@ -22,6 +22,7 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy, LiteralRuntime, LiteralString
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _start(
|
||||
model_name: str,
|
||||
/,
|
||||
@@ -108,6 +109,7 @@ def _start(
|
||||
return start_command_factory(start_command if not _serve_grpc else start_grpc_command, model_name, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=_serve_grpc).main(
|
||||
args=args if len(args) > 0 else None, standalone_mode=False
|
||||
)
|
||||
|
||||
@inject
|
||||
def _build(
|
||||
model_name: str,
|
||||
@@ -213,6 +215,7 @@ def _build(
|
||||
if matched is None:
|
||||
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
return bentoml.get(matched.group(1), _bento_store=bento_store)
|
||||
|
||||
def _import_model(
|
||||
model_name: str,
|
||||
/,
|
||||
@@ -262,6 +265,7 @@ def _import_model(
|
||||
if additional_args is not None: args.extend(additional_args)
|
||||
if quantize is not None: args.extend(['--quantize', quantize])
|
||||
return import_command.main(args=args, standalone_mode=False)
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
'''List all available models within the local store.'''
|
||||
from .entrypoint import models_command
|
||||
|
||||
@@ -8,6 +8,7 @@ import openllm
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import container_registry_option, machine_option
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
|
||||
@click.command(
|
||||
'build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
|
||||
@@ -12,6 +12,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar, machine_option
|
||||
if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@machine_option
|
||||
|
||||
@@ -13,6 +13,7 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import bentoml_cattr
|
||||
if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
|
||||
@@ -11,6 +11,7 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import machine_option, model_complete_envvar, output_option
|
||||
from openllm_core._prompt import process_prompt
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
|
||||
@@ -9,6 +9,7 @@ import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import LiteralOutput, output_option
|
||||
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@output_option(default_value='json')
|
||||
@click.pass_context
|
||||
|
||||
@@ -11,6 +11,7 @@ from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import LiteralOutput, model_complete_envvar, model_name_argument, output_option
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
@output_option(default_value='json')
|
||||
|
||||
@@ -20,11 +20,13 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
if not all('description' in k for k in content.values()): raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
|
||||
|
||||
@@ -7,9 +7,11 @@ import inflection
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
def echo(text: t.Any, fg: str = 'green', _with_style: bool = True, **attrs: t.Any) -> None:
|
||||
attrs['fg'] = fg if not openllm.utils.get_debug_mode() else None
|
||||
if not openllm.utils.get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
|
||||
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS']
|
||||
|
||||
@@ -15,7 +15,9 @@ import typing as t
|
||||
|
||||
import openllm_client
|
||||
if t.TYPE_CHECKING: from openllm_client import AsyncHTTPClient as AsyncHTTPClient, BaseAsyncClient as BaseAsyncClient, BaseClient as BaseClient, HTTPClient as HTTPClient, GrpcClient as GrpcClient, AsyncGrpcClient as AsyncGrpcClient
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(dir(openllm_client))
|
||||
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
return getattr(openllm_client, it)
|
||||
|
||||
@@ -22,6 +22,7 @@ if t.TYPE_CHECKING:
|
||||
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAutoLLMClass:
|
||||
_model_mapping: t.ClassVar[_LazyAutoMapping]
|
||||
|
||||
@@ -81,6 +82,7 @@ class BaseAutoLLMClass:
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])})."
|
||||
)
|
||||
|
||||
def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
if attr is None: return
|
||||
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
|
||||
@@ -93,6 +95,7 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
except ValueError:
|
||||
raise ValueError(f'Could not find {attr} neither in {module} nor in {openllm_module}!') from None
|
||||
raise ValueError(f'Could not find {attr} in {openllm_module}!')
|
||||
|
||||
class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
|
||||
|
||||
@@ -168,4 +171,5 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
if hasattr(key, '__name__') and key.__name__ in self._reverse_config_mapping:
|
||||
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
__all__ = ['BaseAutoLLMClass', '_LazyAutoMapping']
|
||||
|
||||
@@ -9,5 +9,6 @@ MODEL_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLM'), ('dolly_v2', 'DollyV2
|
||||
'opt', 'OPT'
|
||||
), ('stablelm', 'StableLM'), ('starcoder', 'StarCoder'), ('baichuan', 'Baichuan')])
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
class AutoLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_MAPPING
|
||||
|
||||
@@ -7,5 +7,6 @@ from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict([('flan_t5', 'FlaxFlanT5'), ('opt', 'FlaxOPT')])
|
||||
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
|
||||
class AutoFlaxLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_FLAX_MAPPING
|
||||
|
||||
@@ -7,5 +7,6 @@ from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict([('flan_t5', 'TFFlanT5'), ('opt', 'TFOPT')])
|
||||
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
|
||||
class AutoTFLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_TF_MAPPING
|
||||
|
||||
@@ -9,5 +9,6 @@ MODEL_VLLM_MAPPING_NAMES = OrderedDict([('baichuan', 'VLLMBaichuan'), ('dolly_v2
|
||||
'opt', 'VLLMOPT'
|
||||
), ('stablelm', 'VLLMStableLM'), ('starcoder', 'VLLMStarCoder'), ('llama', 'VLLMLlama')])
|
||||
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
|
||||
|
||||
class AutoVLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_VLLM_MAPPING
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class Baichuan(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMBaichuan(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class ChatGLM(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -9,12 +9,15 @@ from openllm_core.config.configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE,
|
||||
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
|
||||
else: torch, transformers, tf = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers'), openllm.utils.LazyLoader('tf', globals(), 'tensorflow')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]:
|
||||
...
|
||||
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
@@ -115,6 +118,7 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
|
||||
return records
|
||||
|
||||
return InstructionTextGenerationPipeline() if _init else InstructionTextGenerationPipeline
|
||||
|
||||
class DollyV2(openllm.LLM['transformers.Pipeline', 'transformers.PreTrainedTokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMDollyV2(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizer']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
|
||||
class Falcon(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMFalcon(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class FlanT5(openllm.LLM['transformers.T5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import openllm
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.config.configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class FlaxFlanT5(openllm.LLM['transformers.FlaxT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class TFFlanT5(openllm.LLM['transformers.TFT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GPTNeoX(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMGPTNeoX(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -3,5 +3,6 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMLlama(openllm.LLM['vllm.LLMEngine', 'transformers.LlamaTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@@ -8,6 +8,7 @@ from openllm.utils import generate_labels, is_triton_available
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_mpt_config(
|
||||
model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True
|
||||
) -> transformers.PretrainedConfig:
|
||||
@@ -22,6 +23,7 @@ def get_mpt_config(
|
||||
# setting max_seq_len
|
||||
config.max_seq_len = max_sequence_length
|
||||
return config
|
||||
|
||||
class MPT(openllm.LLM['transformers.PreTrainedModel', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers, vllm
|
||||
|
||||
class VLLMMPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -11,6 +11,7 @@ if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FlaxOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OPT(openllm.LLM['transformers.OPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import bentoml
|
||||
import openllm
|
||||
from openllm_core.utils import generate_labels
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class TFOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import openllm
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMOPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class StableLM(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMStableLM(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -7,6 +7,7 @@ import openllm
|
||||
from openllm.utils import generate_labels
|
||||
from openllm_core.config.configuration_starcoder import EOD, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.GPT2TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMStarCoder(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -24,6 +24,7 @@ from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
DEFAULT_MODEL_ID = "ybelkada/falcon-7b-sharded-bf16"
|
||||
DATASET_NAME = "timdettmers/openassistant-guanaco"
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
@@ -40,10 +41,12 @@ class TrainingArguments:
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default="constant")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
|
||||
@@ -12,6 +12,7 @@ MAX_NEW_TOKENS = 384
|
||||
|
||||
Q = "Answer the following question, step by step:\n{q}\nA:"
|
||||
question = "What is the meaning of life?"
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("question", default=question)
|
||||
@@ -42,9 +43,11 @@ def main() -> int:
|
||||
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
|
||||
|
||||
return 0
|
||||
|
||||
def _mp_fn(index: t.Any): # noqa # type: ignore
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
main()
|
||||
else:
|
||||
|
||||
@@ -29,6 +29,7 @@ from random import randint, randrange
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
|
||||
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
|
||||
def find_all_linear_names(model):
|
||||
lora_module_names = set()
|
||||
@@ -40,11 +41,13 @@ def find_all_linear_names(model):
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
return list(lora_module_names)
|
||||
|
||||
# Change this to the local converted path if you don't have access to the meta-llama model
|
||||
DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
|
||||
# change this to 'main' if you want to use the latest llama
|
||||
DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135"
|
||||
DATASET_NAME = "databricks/databricks-dolly-15k"
|
||||
|
||||
def format_dolly(sample):
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
|
||||
@@ -52,12 +55,15 @@ def format_dolly(sample):
|
||||
# join all the parts together
|
||||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
|
||||
|
||||
# template dataset to add prompt to each sample
|
||||
def template_dataset(sample, tokenizer):
|
||||
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
|
||||
return sample
|
||||
|
||||
# empty list to save remainder from batches to use in next batch
|
||||
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
|
||||
|
||||
def chunk(sample, chunk_length=2048):
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
global remainder
|
||||
@@ -78,6 +84,7 @@ def chunk(sample, chunk_length=2048):
|
||||
# prepare labels
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split="train")
|
||||
@@ -96,6 +103,7 @@ def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Print total number of samples
|
||||
print(f"Total number of samples: {len(lm_dataset)}")
|
||||
return lm_dataset
|
||||
|
||||
def prepare_for_int4_training(model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True,
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
@@ -130,6 +138,7 @@ def prepare_for_int4_training(model_id: str, model_version: str | None = None, g
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
@@ -141,12 +150,14 @@ class TrainingArguments:
|
||||
report_to: str = dataclasses.field(default="none")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
|
||||
save_strategy: str = dataclasses.field(default="no")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
model_args, training_rags = ModelArguments(), TrainingArguments()
|
||||
else:
|
||||
@@ -160,6 +171,7 @@ else:
|
||||
|
||||
# import the model first hand
|
||||
openllm.import_model("llama", model_id=model_args.model_id, model_version=model_args.model_version)
|
||||
|
||||
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
import peft
|
||||
|
||||
@@ -194,4 +206,5 @@ def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB")
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
|
||||
train_loop(model_args, training_args)
|
||||
|
||||
@@ -24,6 +24,7 @@ from datasets import load_dataset
|
||||
if t.TYPE_CHECKING:
|
||||
from peft import PeftModel
|
||||
DEFAULT_MODEL_ID = "facebook/opt-6.7b"
|
||||
|
||||
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments):
|
||||
return transformers.Trainer(
|
||||
model=model,
|
||||
@@ -31,6 +32,7 @@ def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, da
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
@@ -41,9 +43,11 @@ class TrainingArguments:
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
|
||||
@@ -37,6 +37,7 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from . import constants as constants, ggml as ggml, transformers as transformers
|
||||
P = ParamSpec('P')
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
'''Load the tokenizer from BentoML store.
|
||||
|
||||
@@ -66,10 +67,13 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
return tokenizer
|
||||
|
||||
class _Caller(t.Protocol[P]):
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
...
|
||||
|
||||
_extras = ['get', 'import_model', 'save_pretrained', 'load_model']
|
||||
|
||||
def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
@@ -81,6 +85,7 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
return getattr(importlib.import_module(f'.{llm.runtime}', __name__), fn)(llm, *args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
|
||||
@@ -94,10 +99,13 @@ if t.TYPE_CHECKING:
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M:
|
||||
...
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
|
||||
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'load_tokenizer': return load_tokenizer
|
||||
elif name in _import_structure: return importlib.import_module(f'.{name}', __name__)
|
||||
|
||||
@@ -10,8 +10,10 @@ import openllm
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import M
|
||||
|
||||
_conversion_strategy = {'pt': 'ggml'}
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
'''Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
@@ -31,7 +33,9 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
|
||||
if auto_import:
|
||||
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise
|
||||
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
@@ -34,6 +34,7 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['import_model', 'get', 'load_model', 'save_pretrained']
|
||||
|
||||
@inject
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
|
||||
"""Auto detect model type from given model_id and import it to bentoml's model store.
|
||||
@@ -136,6 +137,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
# in the case where users first run openllm start without the model available locally.
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
return bentomodel
|
||||
|
||||
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
'''Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
@@ -157,6 +159,7 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise err from None
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
'''Load the model from BentoML store.
|
||||
|
||||
@@ -189,6 +192,7 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
if llm.__llm_implementation__ in {'pt', 'vllm'}: check_unintialised_params(model)
|
||||
return t.cast('M', model)
|
||||
|
||||
def save_pretrained(
|
||||
llm: openllm.LLM[M, T],
|
||||
save_directory: str,
|
||||
|
||||
@@ -18,6 +18,7 @@ else:
|
||||
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
@@ -37,10 +38,12 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
|
||||
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
|
||||
if __cls is None: raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
|
||||
return __cls
|
||||
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.config['trust_remote_code']:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
@@ -55,9 +58,11 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
|
||||
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
|
||||
|
||||
def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Model:
|
||||
based: DictStrAny = copy.deepcopy(bentomodel.info.metadata)
|
||||
based.update(metadata)
|
||||
@@ -65,6 +70,7 @@ def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Mod
|
||||
tag=bentomodel.info.tag, module=bentomodel.info.module, labels=bentomodel.info.labels, options=bentomodel.info.options.to_dict(), signatures=bentomodel.info.signatures, context=bentomodel.info.context, api_version=bentomodel.info.api_version, creation_time=bentomodel.info.creation_time, metadata=based
|
||||
))
|
||||
return bentomodel
|
||||
|
||||
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
|
||||
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
|
||||
infer_fn: tuple[str, ...] = ('__call__',)
|
||||
|
||||
@@ -6,8 +6,10 @@ from huggingface_hub import HfApi
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
return any(s.rfilename.endswith('.safetensors') for s in HfApi().model_info(model_id, revision=revision).siblings)
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = '*.safetensors'
|
||||
|
||||
@@ -11,6 +11,7 @@ import openllm
|
||||
if t.TYPE_CHECKING: from ._typing_compat import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(
|
||||
model: str, model_id: str | None = None, quantize: t.Literal['int4', 'int8', 'gptq'] | None = None, runtime: t.Literal['ggml', 'transformers'] = 'transformers', cleanup: bool = False
|
||||
@@ -21,6 +22,7 @@ def build_bento(
|
||||
if cleanup:
|
||||
logger.info('Deleting %s', bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
@@ -36,6 +38,7 @@ def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | N
|
||||
if cleanup:
|
||||
logger.info('Deleting container %s', image_tag)
|
||||
subprocess.check_output([executable, 'rmi', '-f', image_tag])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(
|
||||
model: str,
|
||||
|
||||
@@ -8,17 +8,14 @@ import typing as t
|
||||
|
||||
import openllm_core
|
||||
|
||||
from . import (
|
||||
dummy_flax_objects as dummy_flax_objects,
|
||||
dummy_pt_objects as dummy_pt_objects,
|
||||
dummy_tf_objects as dummy_tf_objects,
|
||||
dummy_vllm_objects as dummy_vllm_objects,
|
||||
)
|
||||
from . import dummy_flax_objects as dummy_flax_objects, dummy_pt_objects as dummy_pt_objects, dummy_tf_objects as dummy_tf_objects, dummy_vllm_objects as dummy_vllm_objects
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
from openllm_core._typing_compat import LiteralRuntime
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
|
||||
return {'runtime': llm.runtime, 'framework': 'openllm', 'model_name': llm.config['model_name'], 'architecture': llm.config['architecture'], 'serialisation_format': llm._serialisation_format}
|
||||
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
|
||||
import openllm
|
||||
if implementation == 'tf': return openllm.AutoTFLLM
|
||||
@@ -26,9 +23,12 @@ def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | o
|
||||
elif implementation == 'pt': return openllm.AutoLLM
|
||||
elif implementation == 'vllm': return openllm.AutoVLLM
|
||||
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
|
||||
|
||||
__all__ = ['generate_labels', 'infer_auto_class', 'dummy_flax_objects', 'dummy_pt_objects', 'dummy_tf_objects', 'dummy_vllm_objects']
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it)
|
||||
else: raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
|
||||
Reference in New Issue
Block a user