# Copyright 2023 BentoML Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import logging import string import typing as t from pathlib import Path from black import decode_bytes, detect_target_versions, get_future_imports from black.comments import list_comments, normalize_fmt_off from black.linegen import LineGenerator, transform_line from black.lines import EmptyLineTracker from black.mode import Feature, Mode, TargetVersion, supports_feature from black.nodes import syms from black.parsing import lib2to3_parse from blib2to3.pgen2 import token from blib2to3.pytree import Leaf, Node if t.TYPE_CHECKING: from black.lines import LinesBlock from fs.base import FS class ModifyNodeProtocol(t.Protocol): @t.overload def __call__(self, node: Node, model_name: str) -> None: ... @t.overload def __call__(self, node: Node, *args: t.Any, **attrs: t.Any) -> None: ... 
logger = logging.getLogger(__name__)

# Comment texts that mark the statement whose string arguments receive the model name.
OPENLLM_MODEL_NAME = {"# openllm: model name"}


class ModelFormatter(string.Formatter):
    """``string.Formatter`` that binds ``{__model_name__}`` to a fixed model name."""

    model_keyword: t.LiteralString = "__model_name__"

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name

    def vformat(
        self,
        format_string: str,
        args: t.Sequence[t.Any] = (),
        kwargs: t.Mapping[str, t.Any] | None = None,
    ) -> str:
        """Format *format_string* with ``model_keyword`` bound to ``self.model_name``.

        ``args`` and ``kwargs`` are optional, so the single-argument call
        ``vformat(s)`` keeps working. The signature still matches
        ``string.Formatter.vformat``, which the inherited ``format`` method calls
        with three arguments. Explicit *kwargs* take precedence over the
        model-name binding.
        """
        mapping: dict[str, t.Any] = {self.model_keyword: self.model_name}
        if kwargs:
            mapping.update(kwargs)
        return super().vformat(format_string, args, mapping)

    def can_format(self, value: str) -> bool:
        """Return True if *value* is a valid format string that ``vformat`` can fill.

        The string must parse, and every replacement field must refer to
        ``model_keyword``, optionally with attribute/index access such as
        ``{__model_name__[0]}``. Strings without any fields also qualify.
        """
        try:
            # ``parse`` returns a lazy iterator: it must be consumed for a
            # malformed string (e.g. a lone "{") to raise ValueError here.
            field_names = [field for _, field, _, _ in self.parse(value) if field is not None]
        except ValueError:
            return False
        # Reject "{}", "{other}", ... which would make vformat raise IndexError/KeyError.
        return all(field.split(".", 1)[0].split("[", 1)[0] == self.model_keyword for field in field_names)

    def is_correct_leaf(self, leaf: Leaf) -> bool:
        """Return True if *leaf* is a string literal this formatter can safely rewrite."""
        return leaf.type == token.STRING and self.can_format(leaf.value)


def recurse_modify_node(node: Node | Leaf, node_type: int, callables: ModifyNodeProtocol, *args: t.Any) -> None:
    """Walk *node* depth-first, calling ``callables(n, *args)`` on every ``Node`` of type *node_type*.

    Matching nodes are still descended into, so nested matches are visited too.
    """
    if isinstance(node, Node) and node.type == node_type:
        callables(node, *args)
    for child in node.children:
        recurse_modify_node(child, node_type, callables, *args)


def modify_node_with_comments(node: Node, model_name: str) -> None:
    """Substitute *model_name* into the string-literal children of *node*.

    *node* is an argument list from a statement marked with
    '# openllm: model name'. Only direct ``Leaf`` children that pass
    ``ModelFormatter.is_correct_leaf`` are rewritten, in place.
    """
    formatter = ModelFormatter(model_name)
    for child in node.children:
        if isinstance(child, Leaf) and formatter.is_correct_leaf(child):
            child.value = formatter.vformat(child.value)


_service_file = Path(__file__).parent.parent / "_service.py"


def write_service(model_name: str, target_path: str, llm_fs: FS) -> None:
    """Render ``_service.py`` for *model_name* and write it to *target_path* in *llm_fs*.

    The written text is prefixed with a "GENERATED BY" header line.
    """
    logger.debug("Generating service for %s to %s", model_name, target_path)
    mode = Mode(target_versions={TargetVersion.PY311}, is_pyi=False)
    # Read as UTF-8 (the default encoding of Python source) rather than the
    # platform's locale encoding.
    src_contents = _service_file.read_text(encoding="utf-8")
    dst_contents = _parse_service_file(src_contents, model_name, mode)
    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses. Admittedly ugly.
    if src_contents != dst_contents:
        dst_contents = _parse_service_file(dst_contents, model_name, mode)
    llm_fs.writetext(target_path, f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n" + dst_contents)


def _inject_model_name(src_node: Node, model_name: str) -> None:
    """Rewrite string arguments of every statement marked '# openllm: model name' in *src_node*."""
    for leaf in src_node.leaves():
        for comment in list_comments(leaf.prefix, is_endmarker=False):
            if comment.value in OPENLLM_MODEL_NAME:
                # The marker comment sits in the prefix of the token that follows
                # the marked statement, so the statement is that token's previous sibling.
                assert leaf.prev_sibling is not None, "'# openllm: model name' line must not be modified."
                recurse_modify_node(leaf.prev_sibling, syms.arglist, modify_node_with_comments, model_name)


def _parse_service_file(src_contents: str, model_name: str, mode: Mode) -> str:
    """This function is an extension of black.format_str, where we can modify the AST before formatting.

    The source is parsed with blib2to3. *model_name* is then substituted into
    the marked string literals, and the tree is rendered the same way as Black's
    single formatting pass.

    Args:
        src_contents: Python source text to rewrite.
        model_name: Value bound to ``{__model_name__}`` placeholders.
        mode: Black formatting mode. Its ``target_versions`` are used for
            parsing and, when non-empty, for feature detection.

    Returns:
        The formatted source. If formatting yields no lines, returns the
        source's newline sequence (CRLF or LF) when the input contained a
        newline, otherwise ``""``.
    """
    src_node = lib2to3_parse(src_contents, mode.target_versions)
    # This is the actual AST handling.
    _inject_model_name(src_node, model_name)

    # NOTE: The below is the same as black.format_str
    dst_blocks: list[LinesBlock] = []
    if mode.target_versions:
        versions = mode.target_versions
    else:
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    context_manager_features = {
        feature for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS} if supports_feature(versions, feature)
    }
    normalize_fmt_off(src_node)
    lines = LineGenerator(mode=mode, features=context_manager_features)
    elt = EmptyLineTracker(mode=mode)
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    block: LinesBlock | None = None
    for current_line in lines.visit(src_node):
        block = elt.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        for line in transform_line(current_line, mode=mode, features=split_line_features):
            block.content_lines.append(str(line))
    if dst_blocks:
        dst_blocks[-1].after = 0
    dst_contents: list[str] = []
    for block in dst_blocks:
        dst_contents.extend(block.all_lines())
    if not dst_contents:
        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
        # and check if normalized_content has more than one line
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        if "\n" in normalized_content:
            return newline
        return ""
    return "".join(dst_contents)