From 0a37bac25dbdfa725929c46c0fb5d3f40229807d Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sat, 27 May 2023 04:46:54 -0700 Subject: [PATCH] feat(codegen): using black parser (#5) --- pyproject.toml | 13 +-- src/openllm/_service.py | 13 ++- src/openllm/utils/codegen.py | 163 +++++++++++++++++++++++++---------- 3 files changed, 132 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 76307079..37351562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,14 +62,14 @@ dependencies = [ "inflection", # pydantic 2 ftw "pydantic", - # astor for generating service file. - "astor", + # black for generating service file. + "black[jupyter]==23.3.0", ] [project.urls] -Documentation = "https://github.com/bentoml/open-llm-server#readme" -Issues = "https://github.com/bentoml/open-llm-server/issues" -Source = "https://github.com/bentoml/open-llm-server" +Documentation = "https://github.com/llmsys/openllm#readme" +Issues = "https://github.com/llmsys/openllm/issues" +Source = "https://github.com/llmsys/openllm" [project.scripts] openllm = "openllm.cli:cli" @@ -102,7 +102,8 @@ python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "pyright"] +dependencies = ["ruff>=0.0.243", "pyright"] + [tool.hatch.envs.lint.scripts] typing = "pyright {args:src/openllm tests}" style = ["ruff {args:.}", "black --check --diff {args:.}"] diff --git a/src/openllm/_service.py b/src/openllm/_service.py index d3b9f2f2..8385cd90 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -1,3 +1,12 @@ +""" +The service definition for running any LLMService. + +Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm +internally to generate the correct model service when bundling the LLM to a Bento. +This will ensure that 'bentoml serve llm-bento' will work accordingly. + +The generation code lives under ./utils/codegen.py +""" from __future__ import annotations import os @@ -7,10 +16,6 @@ import bentoml import openllm -# NOTE: The below code should not be changed as it will be used by the ast parser -# to generate the service code. This is the current drawback of this approach, but -# good for now. The below make sure that users who use `bentoml serve llm-bento` would -# work. model = os.environ.get("OPENLLM_MODEL", "{__model_name__}") # openllm: model name llm_config = openllm.AutoConfig.for_model(model) diff --git a/src/openllm/utils/codegen.py b/src/openllm/utils/codegen.py index a1ad22a3..62c12ebf 100644 --- a/src/openllm/utils/codegen.py +++ b/src/openllm/utils/codegen.py @@ -14,68 +14,137 @@ from __future__ import annotations -import ast import logging +import string import typing as t from pathlib import Path -import astor +from black import decode_bytes, detect_target_versions, get_future_imports +from black.comments import list_comments, normalize_fmt_off +from black.linegen import LineGenerator, transform_line +from black.lines import EmptyLineTracker +from black.mode import Feature, Mode, TargetVersion, supports_feature +from black.nodes import syms +from black.parsing import lib2to3_parse +from blib2to3.pgen2 import token +from blib2to3.pytree import Leaf, Node if t.TYPE_CHECKING: + from black.lines import LinesBlock from fs.base import FS logger = logging.getLogger(__name__) +OPENLLM_MODEL_NAME = {"# openllm: model name"} + + +class ModelFormatter(string.Formatter): + model_keyword: t.LiteralString = "__model_name__" + + def __init__(self, model_name: str): + super().__init__() + self.model_name = model_name + + def vformat(self, format_string: str) -> str: + return super().vformat(format_string, (), {self.model_keyword: self.model_name}) + + def can_format(self, value: str) -> bool: + try: + self.parse(value) + return True + except ValueError: + return False + + def is_correct_leaf(self, leaf: Leaf): + return leaf.type == token.STRING and self.can_format(leaf.value) + + +def recurse_modify_node(node: Node | Leaf, node_type: int, model_name: str) -> Node | None: + if isinstance(node, Node) and node.type == node_type: + modify_node_with_comments(node, model_name) + for child in node.children: + recurse_modify_node(child, node_type, model_name) + + +def modify_node_with_comments(node: Node, model_name: str): + """ + Modify the node with comments '# openllm: model name' and replace + the formatted value with the actual model name. + """ + _formatter = ModelFormatter(model_name) + for children in node.children: + if isinstance(children, Leaf) and _formatter.is_correct_leaf(children): + children.value = _formatter.vformat(children.value) + + +_service_file = Path(__file__).parent.parent / "_service.py" + def write_service(model_name: str, target_path: str, llm_fs: FS): logger.debug("Generating service for %s to %s", model_name, target_path) - service_file = Path(__file__).parent.parent / "_service.py" - with open(service_file.__fspath__(), "rb") as f: - node = ast.parse(f.read()) - generator = ServiceGenerator(model_name) - generator.visit(node) - llm_fs.writetext( - target_path, f"# GENERATED BY 'openllm bundle {model_name}'. DO NOT EDIT\n" + "".join(generator.result) - ) + + mode = Mode(target_versions={TargetVersion.PY311}, is_pyi=False) + with open(_service_file.__fspath__(), "r") as f: + src_contents = f.read() + + dst_contents = _parse_service_file(src_contents, model_name, mode) + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + if src_contents != dst_contents: + dst_contents = _parse_service_file(dst_contents, model_name, mode) + + llm_fs.writetext(target_path, f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n" + dst_contents) -class ServiceGenerator(astor.SourceGenerator): - def __init__(self, model_name: str, indent_width: str | None = None): - self.model_name = model_name - if indent_width is None: - indent_width = " " * 4 - super().__init__(indent_width) +def _parse_service_file(src_contents: str, model_name: str, mode: Mode) -> str: + """This function is an extension of black.format_str, + where we can modify the AST before formatting.""" - def visit_Assign(self, node: ast.Assign): - self.newline(node) + src_node = lib2to3_parse(src_contents, mode.target_versions) - # we need to handle the value assignment for model name, "{__model_name__}" - # The first iteration is heuristic, as we will loop through all of the call args, which could be very slow. - # only parse os.environ.get - try: - if ( - isinstance(node.value, ast.Call) - and isinstance(node.value.func, ast.Attribute) - and node.value.func.attr == "get" - ): - if isinstance(node.value.func.value, ast.Attribute) and node.value.func.value.attr == "environ": - if isinstance(node.value.func.value.value, ast.Name) and node.value.func.value.value.id == "os": - # right now, the last arg is the default value - arg = node.value.args[-1] - if not isinstance(arg, ast.Constant): - pass - else: - string_value = arg.value.format(__model_name__=self.model_name) - node.value.args[-1] = ast.Constant(value=string_value) - else: - pass - except AttributeError as err: - logger.error(f"Error parsing os.environ.get: {err}") + # This is the actual AST handling + for leaf in src_node.leaves(): + for comment in list_comments(leaf.prefix, is_endmarker=False): + if comment.value in OPENLLM_MODEL_NAME: + assert leaf.prev_sibling is not None, "'# openllm: model name' line must not be modified." + recurse_modify_node(leaf.prev_sibling, syms.arglist, model_name) - # Finally, actually write the assignment - for idx, target in enumerate(node.targets): - if idx: - self.write(", ") - self.visit(target) - self.write(" = ") - self.visit(node.value) + # NOTE: The below is the same as black.format_str + dst_blocks: list[LinesBlock] = [] + if mode.target_versions: + versions = mode.target_versions + else: + future_imports = get_future_imports(src_node) + versions = detect_target_versions(src_node, future_imports=future_imports) + + context_manager_features = { + feature for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS} if supports_feature(versions, feature) + } + normalize_fmt_off(src_node) + lines = LineGenerator(mode=mode, features=context_manager_features) + elt = EmptyLineTracker(mode=mode) + split_line_features = { + feature + for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} + if supports_feature(versions, feature) + } + block: LinesBlock | None = None + for current_line in lines.visit(src_node): + block = elt.maybe_empty_lines(current_line) + dst_blocks.append(block) + for line in transform_line(current_line, mode=mode, features=split_line_features): + block.content_lines.append(str(line)) + if dst_blocks: + dst_blocks[-1].after = 0 + dst_contents: list[str] = [] + for block in dst_blocks: + dst_contents.extend(block.all_lines()) + if not dst_contents: + # Use decode_bytes to retrieve the correct source newline (CRLF or LF), + # and check if normalized_content has more than one line + normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8")) + if "\n" in normalized_content: + return newline + return "" + return "".join(dst_contents)