Files
OpenLLM/tools/update-models-import.py
2023-07-31 07:54:26 -04:00

50 lines
2.5 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import typing as t
from pathlib import Path
import openllm
# Module file rewritten by this script: src/openllm/models/__init__.py, resolved
# relative to the directory two levels above this script's own path.
_TARGET_FILE = Path(__file__).parent.parent.joinpath("src", "openllm", "models", "__init__.py")
def comment_generator(comment_type: str, action: t.Literal["start", "stop"] = "start", indentation: int = 0) -> str:
  """Build one marker comment line that delimits a generated region.

  Args:
    comment_type: Label of the region (e.g. ``"module"`` or ``"types"``).
    action: Whether this marker opens (``"start"``) or closes (``"stop"``) the region.
    indentation: Number of leading spaces placed before the ``#``.

  Returns:
    A newline-terminated line of the form ``"<spaces># <script>: <action> <comment_type>\\n"``,
    where ``<script>`` is the base name of this file.
  """
  script_name = os.path.basename(__file__)
  leading = " " * indentation
  return f"{leading}# {script_name}: {action} {comment_type}\n"
# Marker lines that bracket the two regenerated regions of the target file.
# The "module" pair sits at column 0; the "types" pair is indented by two
# spaces, matching the two-space indent applied to the generated stub imports.
START_MODULE_COMMENT = comment_generator("module")
STOP_MODULE_COMMENT = comment_generator("module", "stop")
START_TYPES_COMMENT = comment_generator("types", indentation=2)
STOP_TYPES_COMMENT = comment_generator("types", "stop", indentation=2)
@openllm.utils.apply(lambda v: sorted([" "*2 + _ for _ in v], key=lambda k: k.split()[-1]))
def create_stubs_import() -> list[str]:
  """Return one ``from . import <name> as <name>`` line per model sub-package.

  Sub-packages are the directories next to the target ``__init__.py``, excluding
  ``__pycache__``. The decorator presumably post-processes the returned list with its
  lambda (prefixing each line with two spaces and sorting by the imported alias, the
  last whitespace-separated token); the exact semantics of ``openllm.utils.apply``
  are not visible here.
  """
  return [
    f"from . import {p.name} as {p.name}\n"
    for p in _TARGET_FILE.parent.glob("*/")
    # A trailing "/" in Path.glob only restricts matches to directories on
    # Python 3.11+; is_dir() keeps older interpreters from emitting imports
    # for stray files sitting in the models directory.
    if p.is_dir() and p.name not in {"__pycache__", "__init__.py", ".DS_Store"}
  ]
def create_module_import(models_dir: t.Optional[Path] = None) -> str:
  """Return the source line defining the ``_MODELS`` set of model package names.

  Args:
    models_dir: Directory whose sub-packages are listed. Defaults to the directory
      containing the target ``__init__.py``.

  Returns:
    A newline-terminated line such as ``_MODELS: set[str] = {'flan_t5', 'llama'}``
    with the quoted names sorted. When no sub-packages exist the value is ``set()``,
    because an empty ``{}`` literal would define a dict rather than a set.
  """
  root = _TARGET_FILE.parent if models_dir is None else models_dir
  names = sorted(
    repr(p.name)
    for p in root.glob("*/")
    # A trailing "/" in Path.glob only restricts matches to directories on
    # Python 3.11+, so filter explicitly to keep stray files out of the set.
    if p.is_dir() and p.name not in ("__pycache__", "__init__.py", ".DS_Store")
  )
  literal = "{" + ", ".join(names) + "}" if names else "set()"
  return f"_MODELS: set[str] = {literal}\n"
def _marker_index(lines: list[str], marker: str) -> int:
  """Return the index of ``marker`` in ``lines``; raise a descriptive ValueError if absent."""
  try:
    return lines.index(marker)
  except ValueError:
    raise ValueError(f"marker line {marker.strip()!r} not found in {_TARGET_FILE}") from None


def main() -> int:
  """Regenerate the auto-managed regions of the target ``__init__.py`` in place.

  The lines between the "module" markers are replaced by a freshly computed
  ``_MODELS`` line, and the lines between the "types" markers by the generated stub
  imports. Everything outside the two marker pairs is preserved as-is.

  Returns:
    0, used as the process exit status.

  Raises:
    ValueError: if a marker line is missing, or the module region does not come
      entirely before the types region.
  """
  # Explicit UTF-8 so reading/writing does not depend on the platform's default encoding.
  with _TARGET_FILE.open("r", encoding="utf-8") as f:
    processed = f.readlines()
  stubs_lines, module_line = create_stubs_import(), create_module_import()
  start_module_idx = _marker_index(processed, START_MODULE_COMMENT)
  stop_module_idx = _marker_index(processed, STOP_MODULE_COMMENT)
  start_types_idx = _marker_index(processed, START_TYPES_COMMENT)
  stop_types_idx = _marker_index(processed, STOP_TYPES_COMMENT)
  # The splice below assumes module region < types region; any other ordering
  # would silently duplicate or drop lines of the target file.
  if not start_module_idx < stop_module_idx < start_types_idx < stop_types_idx:
    raise ValueError(f"generated-region markers in {_TARGET_FILE} are out of order")
  processed = (
    processed[:start_module_idx]
    + [START_MODULE_COMMENT, module_line, STOP_MODULE_COMMENT]
    + processed[stop_module_idx + 1:start_types_idx]
    + [START_TYPES_COMMENT, *stubs_lines, STOP_TYPES_COMMENT]
    + processed[stop_types_idx + 1:]
  )
  with _TARGET_FILE.open("w", encoding="utf-8") as f:
    f.writelines(processed)
  return 0
if __name__ == "__main__":
  # Run the regeneration and hand main()'s integer result back as the exit status.
  _exit_code = main()
  raise SystemExit(_exit_code)