Files
OpenLLM/tools/update-dummy.py
2023-09-01 17:00:49 +00:00

66 lines
3.1 KiB
Python
Executable File

#!/usr/bin/env python3
from __future__ import annotations
import os, typing as t, sys
from pathlib import Path
_ROOT = Path(__file__).parent.parent
sys.path.insert(0, (_ROOT / 'openllm-core' / 'src').__fspath__())
sys.path.insert(1, (_ROOT / 'openllm-python' / 'src').__fspath__())
from openllm_core._typing_compat import LiteralBackend
from openllm.models import auto
from openllm import CONFIG_MAPPING
if t.TYPE_CHECKING: from collections import OrderedDict
# Per model-config requirement names with '-' normalised to '_'; None when the config declares none.
config_requirements = {
    name: ([req.replace('-', '_') for req in cfg.__openllm_requirements__] if cfg.__openllm_requirements__ else None)
    for name, cfg in CONFIG_MAPPING.items()
}
# Backend -> dependency listed first in every generated stub for that backend.
# ``[:-2]`` drops the last two LiteralBackend members, so main() emits no dummy module for them;
# zip pairs the remaining members positionally with these names (NOTE(review): relies on
# LiteralBackend's member order — confirm when that Literal changes).
_dependencies: dict[LiteralBackend, str] = dict(zip(LiteralBackend.__args__[:-2], ('torch', 'tensorflow', 'flax', 'vllm')))
# Backend -> name of the Auto* placeholder class generated for that backend.
_auto: dict[str, str] = dict(zip(LiteralBackend.__args__[:-2], ('AutoLLM', 'AutoTFLLM', 'AutoFlaxLLM', 'AutoVLLM')))
def get_target_dummy_file(backend: LiteralBackend) -> Path:
    """Return the path of the generated ``dummy_<backend>_objects.py`` module under openllm/utils."""
    utils_dir = _ROOT.joinpath('openllm-python', 'src', 'openllm', 'utils')
    return utils_dir / f'dummy_{backend}_objects.py'
def mapping_names(backend: LiteralBackend):
    """Return the attribute name, on ``openllm.models.auto``, of the mapping for *backend*.

    The ``'pt'`` backend uses the unprefixed ``MODEL_MAPPING_NAMES``; every other backend
    embeds its upper-cased name, e.g. ``'tf'`` -> ``MODEL_TF_MAPPING_NAMES``.
    """
    if backend == 'pt':
        return 'MODEL_MAPPING_NAMES'
    return 'MODEL_' + backend.upper() + '_MAPPING_NAMES'
def get_mapping(backend: LiteralBackend) -> OrderedDict[t.Any, t.Any]:
    """Fetch *backend*'s mapping from ``openllm.models.auto``.

    Callers iterate its keys as model names and use its values as generated class names.
    """
    attr_name = mapping_names(backend)
    return getattr(auto, attr_name)
def make_class_stub(model_name: str, backend: LiteralBackend, indentation: int = 2, auto: bool = False) -> list[str]:
    """Render the source lines of one placeholder class for *backend*.

    Args:
        model_name: Key into the backend mapping and into ``config_requirements``. The sentinel
            ``'__default__'`` skips the per-model requirement lookup.
        backend: Backend whose entry in ``_dependencies`` is listed first among the class's backends.
        indentation: Number of spaces used to indent the class body in the generated code.
        auto: When true, name the class after ``_auto[backend]`` instead of the mapping entry
            for *model_name*.

    Returns:
        Three lines: the class header, a ``_backends`` list, and an ``__init__`` that passes the
        same quoted list to ``_require_backends``.
    """
    primary = _dependencies[backend]
    extras: t.List[str] = []
    if model_name != '__default__' and config_requirements[model_name]:
        extras = t.cast(t.List[str], config_requirements[model_name])
    # Double-quoted, comma-separated list reused verbatim in both generated lines.
    quoted = ','.join(f'"{dep}"' for dep in (primary, *extras))
    if auto:
        class_name = _auto[backend]
    else:
        class_name = get_mapping(backend)[model_name]
    body_indent = ' ' * indentation
    return [
        f'class {class_name}(metaclass=_DummyMetaclass):',
        f'{body_indent}_backends=[{quoted}]',
        f'{body_indent}def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,[{quoted}])',
    ]
def write_stub(backend: LiteralBackend, _path: str) -> list[str]:
    """Build the dummy-objects module for *backend* as a list of source lines.

    Args:
        backend: Backend whose mapping drives which placeholder classes are emitted.
        _path: Path of this generator script, embedded in the "do not edit" header.

    Returns:
        Header lines, one placeholder class per mapping entry, the Auto* placeholder, then a
        ``None`` mapping attribute and an ``__all__`` list. The last line carries a trailing
        newline, so joining with ``'\\n'`` yields a newline-terminated file.
    """
    header = [
        f'# This file is generated by {_path}. DO NOT EDIT MANUALLY!',
        f'# To update this, run ./{_path}',
        'from __future__ import annotations',
        'import typing as _t',
        'from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends',
    ]
    mapping = get_mapping(backend)
    classes: list[str] = []
    for model_name in mapping:
        classes.extend(make_class_stub(model_name, backend))
    classes.extend(make_class_stub('__default__', backend, auto=True))
    attr_name = mapping_names(backend)
    exported = ','.join(f'"{class_name}"' for class_name in mapping.values())
    footer = [
        f'{attr_name}:_t.Any=None',
        f'__all__:list[str]=["{attr_name}","{_auto[backend]}",{exported}]\n',
    ]
    return header + classes + footer
def main() -> int:
    """Regenerate ``dummy_<backend>_objects.py`` for every backend in ``_dependencies``.

    Returns:
        ``0``, used as the process exit status.
    """
    # resolve() makes the parent directory name available even when __file__ is relative, and the
    # explicit '/' keeps the path embedded in the generated header identical on every OS
    # (os.path.join would emit a backslash on Windows).
    script = Path(__file__).resolve()
    _path = f'{script.parent.name}/{script.name}'
    for backend in _dependencies:
        # Pin encoding and newline so the generated files are byte-identical across platforms
        # (text mode would otherwise translate '\n' to CRLF on Windows).
        with get_target_dummy_file(backend).open('w', encoding='utf-8', newline='\n') as f:
            f.write('\n'.join(write_stub(backend, _path)))
    return 0
if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())