From 8c2867d26dfff8a4cf33bc59d5a8dee159f3256a Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Mon, 31 Jul 2023 07:54:26 -0400 Subject: [PATCH] style: define experimental guidelines (#168) --- .pre-commit-config.yaml | 6 + DEVELOPMENT.md | 9 +- README.md | 2 + STYLE.md | 160 ++ changelog.d/168.chore.md | 3 + examples/bentoml-demo/service.py | 8 +- examples/langchain-chains-demo/service.py | 35 +- examples/langchain-tools-demo/service.py | 9 +- pyproject.toml | 44 + src/openllm/__init__.py | 369 ++- src/openllm/__main__.py | 4 +- src/openllm/_configuration.py | 2545 +++++++---------- src/openllm/_generation.py | 30 +- src/openllm/_llm.py | 1829 ++++++------ src/openllm/_prompt.py | 44 +- src/openllm/_quantisation.py | 155 +- src/openllm/_schema.py | 132 +- src/openllm/_service.py | 78 +- src/openllm/_strategies.py | 600 ++-- src/openllm/_types.py | 134 +- src/openllm/bundle/__init__.py | 29 +- src/openllm/bundle/_package.py | 287 +- src/openllm/bundle/oci/__init__.py | 86 +- src/openllm/cli/__init__.py | 1 - src/openllm/cli/_factory.py | 539 ++-- src/openllm/cli/entrypoint.py | 1631 +++++------ src/openllm/cli/ext/__init__.py | 1 - src/openllm/cli/ext/build_base_container.py | 18 +- src/openllm/cli/ext/dive_bentos.py | 23 +- src/openllm/cli/ext/get_containerfile.py | 38 +- src/openllm/cli/ext/get_prompt.py | 52 +- src/openllm/cli/ext/list_bentos.py | 12 +- src/openllm/cli/ext/list_models.py | 18 +- src/openllm/cli/termui.py | 6 +- src/openllm/client.py | 28 +- src/openllm/exceptions.py | 21 +- src/openllm/models/__init__.py | 35 +- src/openllm/models/auto/__init__.py | 98 +- src/openllm/models/auto/configuration_auto.py | 139 +- src/openllm/models/auto/factory.py | 198 +- src/openllm/models/auto/modeling_auto.py | 20 +- src/openllm/models/auto/modeling_flax_auto.py | 11 +- src/openllm/models/auto/modeling_tf_auto.py | 11 +- src/openllm/models/auto/modeling_vllm_auto.py | 11 +- src/openllm/models/baichuan/__init__.py | 28 +- .../models/baichuan/configuration_baichuan.py | 35 +- .../models/baichuan/modeling_baichuan.py | 22 +- src/openllm/models/chatglm/__init__.py | 28 +- .../models/chatglm/configuration_chatglm.py | 45 +- .../models/chatglm/modeling_chatglm.py | 75 +- src/openllm/models/dolly_v2/__init__.py | 28 +- .../models/dolly_v2/configuration_dolly_v2.py | 38 +- .../models/dolly_v2/modeling_dolly_v2.py | 224 +- src/openllm/models/falcon/__init__.py | 28 +- .../models/falcon/configuration_falcon.py | 47 +- src/openllm/models/falcon/modeling_falcon.py | 44 +- src/openllm/models/flan_t5/__init__.py | 64 +- .../models/flan_t5/configuration_flan_t5.py | 31 +- .../models/flan_t5/modeling_flan_t5.py | 40 +- .../models/flan_t5/modeling_flax_flan_t5.py | 23 +- .../models/flan_t5/modeling_tf_flan_t5.py | 15 +- src/openllm/models/gpt_neox/__init__.py | 28 +- .../models/gpt_neox/configuration_gpt_neox.py | 23 +- .../models/gpt_neox/modeling_gpt_neox.py | 32 +- src/openllm/models/llama/__init__.py | 48 +- .../models/llama/configuration_llama.py | 71 +- src/openllm/models/llama/modeling_llama.py | 42 +- .../models/llama/modeling_vllm_llama.py | 10 +- src/openllm/models/mpt/__init__.py | 30 +- src/openllm/models/mpt/configuration_mpt.py | 44 +- src/openllm/models/mpt/modeling_mpt.py | 117 +- src/openllm/models/opt/__init__.py | 82 +- src/openllm/models/opt/configuration_opt.py | 46 +- src/openllm/models/opt/modeling_flax_opt.py | 36 +- src/openllm/models/opt/modeling_opt.py | 45 +- src/openllm/models/opt/modeling_tf_opt.py | 35 +- src/openllm/models/opt/modeling_vllm_opt.py | 12 +- src/openllm/models/stablelm/__init__.py | 28 +- .../models/stablelm/configuration_stablelm.py | 28 +- .../models/stablelm/modeling_stablelm.py | 37 +- src/openllm/models/starcoder/__init__.py | 28 +- .../starcoder/configuration_starcoder.py | 32 +- .../models/starcoder/modeling_starcoder.py | 94 +- src/openllm/playground/falcon_tuned.py | 85 +- src/openllm/playground/features.py | 56 +- src/openllm/playground/llama2_qlora.py | 271 +- src/openllm/playground/opt_tuned.py | 67 +- src/openllm/serialisation/__init__.py | 130 +- src/openllm/serialisation/constants.py | 17 +- src/openllm/serialisation/ggml.py | 64 +- src/openllm/serialisation/transformers.py | 356 +-- src/openllm/testing.py | 103 +- src/openllm/utils/__init__.py | 478 ++-- src/openllm/utils/analytics.py | 112 +- src/openllm/utils/codegen.py | 366 +-- src/openllm/utils/dantic.py | 702 +++-- src/openllm/utils/dummy_flax_objects.py | 23 +- .../utils/dummy_pt_and_cpm_kernels_objects.py | 13 +- .../utils/dummy_pt_and_einops_objects.py | 6 +- .../utils/dummy_pt_and_triton_objects.py | 6 +- src/openllm/utils/dummy_pt_objects.py | 58 +- src/openllm/utils/dummy_tf_objects.py | 23 +- src/openllm/utils/dummy_vllm_objects.py | 24 +- src/openllm/utils/import_utils.py | 432 ++- src/openllm/utils/lazy.py | 306 +- src/openllm/utils/representation.py | 74 +- src/openllm_client/__init__.py | 1 - src/openllm_client/runtimes/base.py | 492 ++-- src/openllm_client/runtimes/grpc.py | 137 +- src/openllm_client/runtimes/http.py | 146 +- tests/__init__.py | 3 +- tests/_strategies/_configuration.py | 74 +- tests/client_test.py | 6 +- tests/configuration_test.py | 247 +- tests/conftest.py | 64 +- tests/models/conftest.py | 360 +-- tests/models/flan_t5_test.py | 38 +- tests/models/opt_test.py | 38 +- tests/models_test.py | 17 +- tests/package_test.py | 49 +- tests/strategies_test.py | 276 +- tools/assert-model-table-latest | 19 +- tools/dependencies.py | 330 +-- tools/generate-coverage.py | 67 +- tools/update-config-stubs.py | 178 +- tools/update-models-import.py | 33 +- tools/update-readme.py | 109 +- tools/write-coverage-report.py | 58 +- 128 files changed, 8314 insertions(+), 9472 deletions(-) create mode 100644 STYLE.md create mode 100644 changelog.d/168.chore.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4f67bb68..ed4dc4ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,6 +32,12 @@ repos: types: [python] exclude: ^(docs|tools|tests) args: [--config=pyproject.toml] + - repo: https://github.com/google/yapf + rev: v0.40.1 + hooks: + - id: yapf + types: [python] + args: [--parallel, --recursive] - repo: local hooks: - id: mypy diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 8b0630bb..2bba6352 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -155,8 +155,13 @@ hatch run tests:snapshot-models ## Working with Git -To filter out most of the generated commits for infrastructure, use ``--invert-grep`` in conjunction with ``--grep`` -to filter out all commits with regex `"[generated]"` +To filter out most of the generated commits for infrastructure, use +`--invert-grep` in conjunction with `--grep` to filter out all commits with +regex `"[generated]"` + +## Style + +See [STYLE.md](STYLE.md) for our style guide. ## Releasing a New Version diff --git a/README.md b/README.md index 3c662a03..4c4d026c 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ python_version Hatch + + code style Ruff diff --git a/STYLE.md b/STYLE.md new file mode 100644 index 00000000..ad1afe95 --- /dev/null +++ b/STYLE.md @@ -0,0 +1,160 @@ +## the coding style + +This documentation serves as a brief discussion of the coding style used for +OpenLLM. As you have noticed, it is different from the conventional +[PEP8](https://peps.python.org/pep-0008/) style as across many Python projects. +The manifestation of OpenLLM code style is a combination of +[Google Python Style](https://google.github.io/styleguide/pyguide.html), +inspiration from coding language such as APL, Haskell, and is designed for fast, +experimental development and prototyping. + +Everyone always has their own opinions on style. I believe this is exemplified +further within the Python community, as it tries to be beginner-friendly, and +therefore most people hold a very strong opinion on styling. I don't have a +strong opinion on style either (I don't have any issue with PEP8, as we use it +for our other projects), as long as: + +- You don't use any linter, formatter that change the style drastically other + than what specified within the projects' [`pyproject.toml`](./pyproject.toml). +- The code you contribute is not widely different from the style of the code + surrounding it. + +With that being said, I want to use this project as a playground the explore a +style that is both: "feels natural" and expressive for mathematical reasoning. I +hope that you find this guide somewhat thought-provoking and interesting, that +you can iterate and try to adopt some of them as part of the process +contributing to the library. + +While PEP8 is a great base for a style guide, I find it to be having way too +much white spaces and makes the code feels 'robotic'. Having a deterministic +style and formatter is great to reduce the overhead of stylistic discussions, +but I think it is important to write code that express the intent of reasoning. +(_The policy here is definitely not "shovel everything into one line", but +rather "compact and flowing"_) + +The styling is heavily inspired by +[Kenneth Iverson's](https://en.wikipedia.org/wiki/Kenneth_E._Iverson) 1979 +Turing award lecture, +[Notation as a Tool of Thought](https://www.eecg.toronto.edu/~jzhu/csc326/readings/iverson.pdf), +and a lot of the stylistic inspiration comes from +[Jeremy Howard's](https://jeremy.fast.ai/) [fastai](https://docs.fast.ai/). One +thing that has been stuck with me ever since is the idea of "brevity facilitates +reasoning", as such the tersity of style aren't just for the sake of shortness, +rather the brevity of expression. (it enables +[expository programming](http://archive.vector.org.uk/art10000980), combining +with prototyping new ideas and logics within models implementation) + +## some guidelines + +Though I have stopped using deterministic formatter and linter, I do understand +that people have preferences for using these tools, and it plays nicely with IDE +and editors. As such, I included a [`pyproject.toml`](./pyproject.toml) file +that specifies some configuration for the tools that makes it compiliant with +the repository's style. In short, some of the tools include `ruff`, `yapf`, and +`interrogate`. Since we manage everything via `hatch`, refer back to the +[DEVELOPMENT.md](./DEVELOPMENT.md) for more information on this. + +Overtime, Python has incorporated a lot of features that supports this style of +coding, including list comprehension, generator expression, lambda, array-based +programming. Yet, Python will remain verbose per se, and the goal is that to +make code fit nicely on a screen, and we don't have to always scroll downwards. + +While brevity is important, it is also important to make sure functions are +somewhat, type-safe. Since there is no real type-safety when working with +Python, typing should be a best-effort to make sure we don't introduce too many +bugs. + +### naming + +- follow Python standard for this, I don't have too much opinion on this. Just + make sure that it is descriptive, and the abbreviation describes the intent of + the variable. i.e: `to_gpu` instead of `t_gpu`, `to_cpu` instead of `t_cpu`. +- any math-related notation or neural net layers should be expressive and stay + close to the paper as much as possible. For example: `lm_head.weight` instead + of `lm_head.w`. Espically for implementing custom kernels and layers, it is + crucial to follow its nomenclature. E.g: `conv1` instead of + `first_conv_layer`. + +_If you have any suggestions, feel free to give it on our discord server!_ + +### layout + +- Preferably not a lot of whitespaces, but rather flowing. If you can fit + everything for `if`, `def` or a `return` within one line, then there's no need + to break it into multiple line: + + ```python + def foo(x): return rotate_cv(x) if x > 0 else -x + ``` + +- imports should be grouped by their types, and each import should be designated + on its own line. + + ```python + import os + import sys + ``` + This is partially to make it easier to work with merge-conflicts, and easier + for IDE to navigate context definition. + +- indent with 2 spaces, which follow the Google codestyle. + +- With regards to writing operator, try to follow the domain-specific notation. + I.e: when writing pathlib, just don't add space since that is not how you + write a path in the terminal. `yapf` will try to accommodate some of this + changes. + +- Avoid trailing whitespace + +- use array, pytorch or numpy-based indexing where possible. + +- If you need to export anything, put it in `__all__` or do lazy export for + type-safe checker. + +### misc + +- import alias should be concise and descriptive. A convention is to always + `import typing as t`. +- Writing docstring when it is possible. No need to comment everything asn it + makes the codebase hard to read. For docstring, follow the Google style guide. +- We do lazy imports, so consult some of the `__init__.py` to see how we do it. +- Documentation is still _working-in-progress_, but tldr it will be written in + MDX and will be hosted on the GitHub Pages, so stay tuned! +- If anything that is not used for runtime, just put it under `t.TYPE_CHECKING` + +### note on codegen + +- We also do some codegen for some of the assignment functions. These logics are + largely based on the work of [attrs](https://github.com/python-attrs/attrs) to + ensure fast and isolated codegen in Python. If you need codegen but don't know + how it works, feel free to mention @aarnphm_ on discord! + +## FAQ + +### Why not use `black`? + +`black` is used on our other projects, but I rather find `black` to be very +verbose and overtime it is annoying to work with too much whitespaces. + +### Why not PEP8? + +PEP8 is great if you are writing library such as this, but I'm going to do a lot +of experimenting for implementing papers, so I decided early on that PEP8 is +probably not fit here, and want to explore more expressive style. + +### Editor is complaining about the style, what should I do? + +Kindly ask you to disable linting for this project 🤗. I will try my best to +accomodate with ruff and yapf, but I don't want to spend too much time on this. +It is pretty stragithforward to disable it in your editor, with google. + +### Style might put off new contributors? + +I don't think so, as mentioned before, I don't have too much opinion on style as +long as it somewhat follow what I have described above or the style of the code +surrounding it. I will still accept styles PR as long as it is not too drastic. +Just make sure to add the revision to `.git-blame-ignore-revs` so that +`git blame` would work correctly. + +As for people who are too close-minded about styling, such individuals aren't +the ones we want to work with anyway! diff --git a/changelog.d/168.chore.md b/changelog.d/168.chore.md new file mode 100644 index 00000000..bf91e3be --- /dev/null +++ b/changelog.d/168.chore.md @@ -0,0 +1,3 @@ +Define specific style guideline for the project. See +[STYLE.md](https://github.com/bentoml/OpenLLM/blob/main/STYLE.md) for more +information. diff --git a/examples/bentoml-demo/service.py b/examples/bentoml-demo/service.py index 80c53ec9..e6c6c7e1 100644 --- a/examples/bentoml-demo/service.py +++ b/examples/bentoml-demo/service.py @@ -24,13 +24,11 @@ llm_runner = openllm.Runner(model, llm_config=llm_config) svc = bentoml.Service(name="llm-service", runners=[llm_runner]) - @svc.on_startup def download(_: bentoml.Context): - llm_runner.download_model() - + llm_runner.download_model() @svc.api(input=bentoml.io.Text(), output=bentoml.io.Text()) async def prompt(input_text: str) -> str: - answer = await llm_runner.generate.async_run(input_text) - return answer[0]["generated_text"] + answer = await llm_runner.generate.async_run(input_text) + return answer[0]["generated_text"] diff --git a/examples/langchain-chains-demo/service.py b/examples/langchain-chains-demo/service.py index a05bcca1..52abe7ce 100644 --- a/examples/langchain-chains-demo/service.py +++ b/examples/langchain-chains-demo/service.py @@ -25,23 +25,20 @@ from bentoml.io import JSON from bentoml.io import Text class Query(BaseModel): - industry: str - product_name: str - keywords: t.List[str] - llm_config: t.Dict[str, t.Any] - + industry: str + product_name: str + keywords: t.List[str] + llm_config: t.Dict[str, t.Any] def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM: - lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False) - lc_llm.runner.download_model() - return lc_llm - + lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False) + lc_llm.runner.download_model() + return lc_llm llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b") prompt = PromptTemplate( - input_variables=["industry", "product_name", "keywords"], - template=""" + input_variables=["industry", "product_name", "keywords"], template=""" You are a Facebook Ads Copywriter with a strong background in persuasive writing and marketing. You craft compelling copy that appeals to the target audience's emotions and needs, peruading them to take action or make a @@ -59,22 +56,12 @@ chain = LLMChain(llm=llm, prompt=prompt) svc = bentoml.Service("fb-ads-copy", runners=[llm.runner]) - @svc.on_startup def download(_: bentoml.Context): - llm.runner.download_model() - - -SAMPLE_INPUT = Query( - industry="SAAS", - product_name="BentoML", - keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"], - llm_config=llm.runner.config.model_dump(), -) + llm.runner.download_model() +SAMPLE_INPUT = Query(industry="SAAS", product_name="BentoML", keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"], llm_config=llm.runner.config.model_dump(),) @svc.api(input=JSON.from_sample(sample=SAMPLE_INPUT), output=Text()) def generate(query: Query): - return chain.run( - {"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)} - ) + return chain.run({"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)}) diff --git a/examples/langchain-tools-demo/service.py b/examples/langchain-tools-demo/service.py index b919c6e6..d67700c0 100644 --- a/examples/langchain-tools-demo/service.py +++ b/examples/langchain-tools-demo/service.py @@ -22,16 +22,11 @@ from bentoml.io import Text SAMPLE_INPUT = "What is the weather in San Francisco?" -llm = OpenLLM( - model_name="dolly-v2", - model_id="databricks/dolly-v2-7b", - embedded=False, -) +llm = OpenLLM(model_name="dolly-v2", model_id="databricks/dolly-v2-7b", embedded=False,) tools = load_tools(["serpapi"], llm=llm) agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION) svc = bentoml.Service("langchain-openllm", runners=[llm.runner]) - @svc.api(input=Text.from_sample(sample=SAMPLE_INPUT), output=Text()) def chat(input_text: str): - return agent.run(input_text) + return agent.run(input_text) diff --git a/pyproject.toml b/pyproject.toml index 0ca01fb0..b1d84d86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,6 +205,7 @@ ignore = [ "PLR0915", "PLR2004", # magic value to use constant "E501", # ignore line length violation + "E702", "PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs. "D103", # Just missing docstring for magic methods. "D102", @@ -262,6 +263,49 @@ avoid-escape = false ] "typings/**/*" = ["D", "F", "E", "PYI002"] +[tool.yapf] +ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true +ALLOW_MULTILINE_DICTIONARY_KEYS = false +ALLOW_MULTILINE_LAMBDAS = false +ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS = false +ALLOW_SPLIT_BEFORE_DICT_VALUE = false +ARITHMETIC_PRECEDENCE_INDICATION = true +BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1 +BLANK_LINES_BETWEEN_TOP_LEVEL_IMPORTS_AND_VARIABLES = 1 +BLANK_LINE_BEFORE_CLASS_DOCSTRING = false +BLANK_LINE_BEFORE_MODULE_DOCSTRING = false +BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false +COALESCE_BRACKETS = true +COLUMN_LIMIT = 384 +CONTINUATION_ALIGN_STYLE = "VALIGN-RIGHT" +DEDENT_CLOSING_BRACKETS = true +DISABLE_ENDING_COMMA_HEURISTIC = true +EACH_DICT_ENTRY_ON_SEPARATE_LINE = false +INDENT_BLANK_LINES = false +INDENT_CLOSING_BRACKETS = false +INDENT_WIDTH = 2 +JOIN_MULTIPLE_LINES = true +NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = true +SPACES_AROUND_SUBSCRIPT_COLON = false +SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = false +SPACE_INSIDE_BRACKETS = false +SPLIT_ALL_COMMA_SEPARATED_VALUES = false +SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES = false +SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED = false +SPLIT_BEFORE_BITWISE_OPERATOR = false +SPLIT_BEFORE_CLOSING_BRACKET = false +SPLIT_BEFORE_DICT_SET_GENERATOR = false +SPLIT_BEFORE_DOT = false +SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = false +SPLIT_BEFORE_FIRST_ARGUMENT = false +SPLIT_BEFORE_LOGICAL_OPERATOR = false +SPLIT_BEFORE_NAMED_ASSIGNS = false +SPLIT_COMPLEX_COMPREHENSION = false +SPLIT_PENALTY_AFTER_OPENING_BRACKET = 10000 +SPLIT_PENALTY_BEFORE_IF_EXPR = 10000 +SPLIT_PENALTY_COMPREHENSION = 3000 +SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT = 8000 + [tool.coverage.paths] openllm = ["src/openllm", "*/openllm/src/openllm"] [tool.coverage.run] diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index 28f68911..77dab8f4 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -32,244 +32,217 @@ from . import utils as utils from .exceptions import MissingDependencyError if utils.DEBUG: - utils.set_debug_mode(True) - utils.set_quiet_mode(False) - logging.basicConfig(level=logging.NOTSET) + utils.set_debug_mode(True) + utils.set_quiet_mode(False) + logging.basicConfig(level=logging.NOTSET) else: - # configuration for bitsandbytes before import - os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") - # The following warnings from bitsandbytes, and probably not that important - # for users to see when DEBUG is False - warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization") - warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization") - warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.") - + # configuration for bitsandbytes before import + os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") + # The following warnings from bitsandbytes, and probably not that important + # for users to see when DEBUG is False + warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization") + warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization") + warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.") _import_structure: dict[str, list[str]] = { - "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"], - "_configuration": ["LLMConfig"], - "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs"], - "_generation": ["StopSequenceCriteria", "StopOnTokens"], - "_quantisation": ["infer_quantisation_config"], - "exceptions": [], - "utils": ["infer_auto_class"], - "models": [], - "client": [], - "bundle": [], - "playground": [], - "testing": [], - "serialisation": ["ggml", "transformers"], - "cli.entrypoint": ["start", "start_grpc", "build", "import_model", "list_models"], + "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"], "_configuration": ["LLMConfig"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs"], "_generation": ["StopSequenceCriteria", "StopOnTokens"], "_quantisation": ["infer_quantisation_config"], "exceptions": [], "utils": ["infer_auto_class"], "models": [], + "client": [], "bundle": [], "playground": [], "testing": [], "serialisation": ["ggml", "transformers"], "cli.entrypoint": ["start", "start_grpc", "build", "import_model", "list_models"], # NOTE: models - "models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES", ], - "models.chatglm": ["ChatGLMConfig"], - "models.baichuan": ["BaichuanConfig"], - "models.dolly_v2": ["DollyV2Config"], - "models.falcon": ["FalconConfig"], - "models.flan_t5": ["FlanT5Config"], - "models.gpt_neox": ["GPTNeoXConfig"], - "models.llama": ["LlamaConfig"], - "models.mpt": ["MPTConfig"], - "models.opt": ["OPTConfig"], - "models.stablelm": ["StableLMConfig"], - "models.starcoder": ["StarCoderConfig"], + "models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"], + "models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"], } # NOTE: torch and cpm_kernels try: - if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError + if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_pt_and_cpm_kernels_objects - _import_structure["utils.dummy_pt_and_cpm_kernels_objects"] = [name for name in dir(dummy_pt_and_cpm_kernels_objects) if not name.startswith("_")] + from .utils import dummy_pt_and_cpm_kernels_objects + _import_structure["utils.dummy_pt_and_cpm_kernels_objects"] = [name for name in dir(dummy_pt_and_cpm_kernels_objects) if not name.startswith("_")] else: - _import_structure["models.chatglm"].extend(["ChatGLM"]) - _import_structure["models.baichuan"].extend(["Baichuan"]) + _import_structure["models.chatglm"].extend(["ChatGLM"]) + _import_structure["models.baichuan"].extend(["Baichuan"]) try: - if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError + if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_pt_and_einops_objects - _import_structure["utils.dummy_pt_and_einops_objects"] = [name for name in dir(dummy_pt_and_einops_objects) if not name.startswith("_")] + from .utils import dummy_pt_and_einops_objects + _import_structure["utils.dummy_pt_and_einops_objects"] = [name for name in dir(dummy_pt_and_einops_objects) if not name.startswith("_")] else: - _import_structure["models.falcon"].extend(["Falcon"]) + _import_structure["models.falcon"].extend(["Falcon"]) try: - if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError + if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_pt_and_triton_objects - _import_structure["utils.dummy_pt_and_triton_objects"] = [name for name in dir(dummy_pt_and_triton_objects) if not name.startswith("_")] + from .utils import dummy_pt_and_triton_objects + _import_structure["utils.dummy_pt_and_triton_objects"] = [name for name in dir(dummy_pt_and_triton_objects) if not name.startswith("_")] else: - _import_structure["models.mpt"].extend(["MPT"]) + _import_structure["models.mpt"].extend(["MPT"]) try: - if not utils.is_torch_available(): raise MissingDependencyError + if not utils.is_torch_available(): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_pt_objects - _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + from .utils import dummy_pt_objects + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models.flan_t5"].extend(["FlanT5"]) - _import_structure["models.dolly_v2"].extend(["DollyV2"]) - _import_structure["models.starcoder"].extend(["StarCoder"]) - _import_structure["models.stablelm"].extend(["StableLM"]) - _import_structure["models.opt"].extend(["OPT"]) - _import_structure["models.gpt_neox"].extend(["GPTNeoX"]) - _import_structure["models.llama"].extend(["Llama"]) - _import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"]) + _import_structure["models.flan_t5"].extend(["FlanT5"]) + _import_structure["models.dolly_v2"].extend(["DollyV2"]) + _import_structure["models.starcoder"].extend(["StarCoder"]) + _import_structure["models.stablelm"].extend(["StableLM"]) + _import_structure["models.opt"].extend(["OPT"]) + _import_structure["models.gpt_neox"].extend(["GPTNeoX"]) + _import_structure["models.llama"].extend(["Llama"]) + _import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"]) try: - if not utils.is_vllm_available(): raise MissingDependencyError + if not utils.is_vllm_available(): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_vllm_objects - _import_structure["utils.dummy_vllm_objects"] = [name for name in dir(dummy_vllm_objects) if not name.startswith("_")] + from .utils import dummy_vllm_objects + _import_structure["utils.dummy_vllm_objects"] = [name for name in dir(dummy_vllm_objects) if not name.startswith("_")] else: - _import_structure["models.llama"].extend(["VLLMLlama"]) - _import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"]) + _import_structure["models.llama"].extend(["VLLMLlama"]) + _import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"]) try: - if not utils.is_flax_available(): raise MissingDependencyError + if not utils.is_flax_available(): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_flax_objects - _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] + from .utils import dummy_flax_objects + _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.flan_t5"].extend(["FlaxFlanT5"]) - _import_structure["models.opt"].extend(["FlaxOPT"]) - _import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"]) + _import_structure["models.flan_t5"].extend(["FlaxFlanT5"]) + _import_structure["models.opt"].extend(["FlaxOPT"]) + _import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"]) try: - if not utils.is_tf_available(): raise MissingDependencyError + if not utils.is_tf_available(): raise MissingDependencyError except MissingDependencyError: - from .utils import dummy_tf_objects - _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")] + from .utils import dummy_tf_objects + _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")] else: - _import_structure["models.flan_t5"].extend(["TFFlanT5"]) - _import_structure["models.opt"].extend(["TFOPT"]) - _import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) + _import_structure["models.flan_t5"].extend(["TFFlanT5"]) + _import_structure["models.opt"].extend(["TFOPT"]) + _import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) # declaration for OpenLLM-related modules if t.TYPE_CHECKING: - from . import bundle as bundle - from . import cli as cli - from . import client as client - from . import exceptions as exceptions - from . import models as models - from . import playground as playground - from . import serialisation as serialisation - from . import testing as testing + from . import bundle as bundle + from . import cli as cli + from . import client as client + from . import exceptions as exceptions + from . import models as models + from . import playground as playground + from . import serialisation as serialisation + from . import testing as testing - # Specific types import - from ._configuration import LLMConfig as LLMConfig - from ._generation import StopOnTokens as StopOnTokens - from ._generation import StopSequenceCriteria as StopSequenceCriteria - from ._llm import LLM as LLM - from ._llm import LLMRunnable as LLMRunnable - from ._llm import LLMRunner as LLMRunner - from ._llm import Runner as Runner - from ._quantisation import infer_quantisation_config as infer_quantisation_config - from ._schema import EmbeddingsOutput as EmbeddingsOutput - from ._schema import GenerationInput as GenerationInput - from ._schema import GenerationOutput as GenerationOutput - from ._schema import MetadataOutput as MetadataOutput - from ._schema import unmarshal_vllm_outputs as unmarshal_vllm_outputs - from .cli.entrypoint import build as build - from .cli.entrypoint import import_model as import_model - from .cli.entrypoint import list_models as list_models - from .cli.entrypoint import start as start - from .cli.entrypoint import start_grpc as start_grpc - from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING - from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES - from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES - from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES - from .models.auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES - from .models.auto import AutoConfig as AutoConfig - from .models.baichuan import BaichuanConfig as BaichuanConfig - from .models.chatglm import ChatGLMConfig as ChatGLMConfig - from .models.dolly_v2 import DollyV2Config as DollyV2Config - from .models.falcon import FalconConfig as FalconConfig - from .models.flan_t5 import FlanT5Config as FlanT5Config - from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig - from .models.llama import LlamaConfig as LlamaConfig - from .models.mpt import MPTConfig as MPTConfig - from .models.opt import OPTConfig as OPTConfig - from .models.stablelm import StableLMConfig as StableLMConfig - from .models.starcoder import StarCoderConfig as StarCoderConfig - from .serialisation import ggml as ggml - from .serialisation import transformers as transformers - from .utils import infer_auto_class as infer_auto_class + # Specific types import + from ._configuration import LLMConfig as LLMConfig + from ._generation import StopOnTokens as StopOnTokens + from ._generation import StopSequenceCriteria as StopSequenceCriteria + from ._llm import LLM as LLM + from ._llm import LLMRunnable as LLMRunnable + from ._llm import LLMRunner as LLMRunner + from ._llm import Runner as Runner + from ._quantisation import infer_quantisation_config as infer_quantisation_config + from ._schema import EmbeddingsOutput as EmbeddingsOutput + from ._schema import GenerationInput as GenerationInput + from ._schema import GenerationOutput as GenerationOutput + from ._schema import MetadataOutput as MetadataOutput + from ._schema import unmarshal_vllm_outputs as unmarshal_vllm_outputs + from .cli.entrypoint import build as build + from .cli.entrypoint import import_model as import_model + from .cli.entrypoint import list_models as list_models + from .cli.entrypoint import start as start + from .cli.entrypoint import start_grpc as start_grpc + from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING + from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES + from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES + from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES + from .models.auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES + from .models.auto import AutoConfig as AutoConfig + from .models.baichuan import BaichuanConfig as BaichuanConfig + from .models.chatglm import ChatGLMConfig as ChatGLMConfig + from .models.dolly_v2 import DollyV2Config as DollyV2Config + from .models.falcon import FalconConfig as FalconConfig + from .models.flan_t5 import FlanT5Config as FlanT5Config + from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig + from .models.llama import LlamaConfig as LlamaConfig + from .models.mpt import MPTConfig as MPTConfig + from .models.opt import OPTConfig as OPTConfig + from .models.stablelm import StableLMConfig as StableLMConfig + from .models.starcoder import StarCoderConfig as StarCoderConfig + from .serialisation import ggml as ggml + from .serialisation import transformers as transformers + from .utils import infer_auto_class as infer_auto_class + # NOTE: torch and cpm_kernels + try: + if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_pt_and_cpm_kernels_objects import * + else: + from .models.baichuan import Baichuan as Baichuan + from .models.chatglm import ChatGLM as ChatGLM + # NOTE: torch and einops + try: + if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_pt_and_einops_objects import * + else: + from .models.falcon import Falcon as Falcon + # NOTE: torch and triton + try: + if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_pt_and_triton_objects import * + else: + from .models.mpt import MPT as MPT + try: + if not utils.is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_pt_objects import * + else: + from .models.auto import MODEL_MAPPING as MODEL_MAPPING + from .models.auto import AutoLLM as AutoLLM + from .models.dolly_v2 import DollyV2 as DollyV2 + from .models.flan_t5 import FlanT5 as FlanT5 + from .models.gpt_neox import GPTNeoX as GPTNeoX + from .models.llama import Llama as Llama + from .models.opt import OPT as OPT + from .models.stablelm import StableLM as StableLM + from .models.starcoder import StarCoder as StarCoder + try: + if not utils.is_vllm_available(): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_vllm_objects import * + else: + from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING + from .models.auto import AutoVLLM as AutoVLLM + from .models.llama import VLLMLlama as VLLMLlama + from .models.opt import VLLMOPT as VLLMOPT + try: + if not utils.is_flax_available(): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_flax_objects import * + else: + from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING + from .models.auto import AutoFlaxLLM as AutoFlaxLLM + from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5 + from .models.opt import FlaxOPT as FlaxOPT + try: + if not utils.is_tf_available(): raise MissingDependencyError + except MissingDependencyError: + from .utils.dummy_tf_objects import * + else: + from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING + from .models.auto import AutoTFLLM as AutoTFLLM + from .models.flan_t5 import TFFlanT5 as TFFlanT5 + from .models.opt import TFOPT as TFOPT - # NOTE: torch and cpm_kernels - try: - if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_pt_and_cpm_kernels_objects import * - else: - from .models.baichuan import Baichuan as Baichuan - from .models.chatglm import ChatGLM as ChatGLM - - # NOTE: torch and einops - try: - if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_pt_and_einops_objects import * - else: - from .models.falcon import Falcon as Falcon - - # NOTE: torch and triton - try: - if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_pt_and_triton_objects import * - else: - from .models.mpt import MPT as MPT - - try: - if not utils.is_torch_available(): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_pt_objects import * - else: - from .models.auto import MODEL_MAPPING as MODEL_MAPPING - from .models.auto import AutoLLM as AutoLLM - from .models.dolly_v2 import DollyV2 as DollyV2 - from .models.flan_t5 import FlanT5 as FlanT5 - from .models.gpt_neox import GPTNeoX as GPTNeoX - from .models.llama import Llama as Llama - from .models.opt import OPT as OPT - from .models.stablelm import StableLM as StableLM - from .models.starcoder import StarCoder as StarCoder - - try: - if not utils.is_vllm_available(): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_vllm_objects import * - else: - from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING - from .models.auto import AutoVLLM as AutoVLLM - from .models.llama import VLLMLlama as VLLMLlama - from .models.opt import VLLMOPT as VLLMOPT - - try: - if not utils.is_flax_available(): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_flax_objects import * - else: - from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING - from .models.auto import AutoFlaxLLM as AutoFlaxLLM - from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5 - from .models.opt import FlaxOPT as FlaxOPT - - try: - if not utils.is_tf_available(): raise MissingDependencyError - except MissingDependencyError: - from .utils.dummy_tf_objects import * - else: - from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING - from .models.auto import AutoTFLLM as AutoTFLLM - from .models.flan_t5 import TFFlanT5 as TFFlanT5 - from .models.opt import TFOPT as TFOPT - -else: sys.modules[__name__] = utils.LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, doc=__doc__, - extra_objects={ +else: + sys.modules[__name__] = utils.LazyModule( + __name__, + globals()["__file__"], _import_structure, module_spec=__spec__, doc=__doc__, extra_objects={ # The below is a special mapping that allows openllm to be used as a dictionary. # This is purely for convenience sake, and should not be used in performance critcal # code. This is also not considered as a public API. "__openllm_special__": {"flax": "AutoFlaxLLM", "tf": "AutoTFLLM", "pt": "AutoLLM", "vllm": "AutoVLLM"}, - }) + } + ) diff --git a/src/openllm/__main__.py b/src/openllm/__main__.py index f999cc5d..cce3bef8 100644 --- a/src/openllm/__main__.py +++ b/src/openllm/__main__.py @@ -20,5 +20,5 @@ To start any OpenLLM model: openllm start --options ... """ if __name__ == "__main__": - from openllm.cli.entrypoint import cli - cli() + from openllm.cli.entrypoint import cli + cli() diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index c5f478fc..73dba4dd 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -88,15 +88,15 @@ from .utils.import_utils import BACKENDS_MAPPING # so that it can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import NotRequired - from typing import Required - from typing import dataclass_transform - from typing import overload + from typing import NotRequired + from typing import Required + from typing import dataclass_transform + from typing import overload else: - from typing_extensions import NotRequired - from typing_extensions import Required - from typing_extensions import dataclass_transform - from typing_extensions import overload + from typing_extensions import NotRequired + from typing_extensions import Required + from typing_extensions import dataclass_transform + from typing_extensions import overload # NOTE: Using internal API from attr here, since we are actually # allowing subclass of openllm.LLMConfig to become 'attrs'-ish @@ -110,885 +110,590 @@ _T = t.TypeVar("_T") LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"] if t.TYPE_CHECKING: - import click - import peft - import vllm + import click + import peft + import vllm - import transformers - from transformers.generation.beam_constraints import Constraint + import transformers + from transformers.generation.beam_constraints import Constraint - from ._types import AnyCallable - from ._types import At + from ._types import AnyCallable + from ._types import At - DictStrAny = dict[str, t.Any] - ListStr = list[str] - ItemgetterAny = itemgetter[t.Any] - FieldTransformers = t.Callable[[_T, list[attr.Attribute[t.Any]]], list[attr.Attribute[t.Any]]] + DictStrAny = dict[str, t.Any] + ListStr = list[str] + ItemgetterAny = itemgetter[t.Any] + FieldTransformers = t.Callable[[_T, list[attr.Attribute[t.Any]]], list[attr.Attribute[t.Any]]] else: - Constraint = t.Any - ListStr = list - DictStrAny = dict - ItemgetterAny = itemgetter + Constraint = t.Any + ListStr = list + DictStrAny = dict + ItemgetterAny = itemgetter - vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm") - transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") - peft = openllm.utils.LazyLoader("peft", globals(), "peft") + vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm") + transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") + peft = openllm.utils.LazyLoader("peft", globals(), "peft") __all__ = ["LLMConfig"] logger = logging.getLogger(__name__) - -config_merger = Merger( - # merge dicts - type_strategies=[(DictStrAny, "merge")], - # override all other types - fallback_strategies=["override"], - # override conflicting types - type_conflict_strategies=["override"], -) - +config_merger = Merger([(DictStrAny, "merge")], ["override"], ["override"]) # case insensitive, but rename to conform with type class _PeftEnumMeta(enum.EnumMeta): - def __getitem__(self, __key: str | t.Any, /) -> t.Any: - if isinstance(__key, str): __key = inflection.underscore(__key).upper() - return self._member_map_[__key] - + def __getitem__(self, __key: str | t.Any, /) -> t.Any: + if isinstance(__key, str): __key = inflection.underscore(__key).upper() + return self._member_map_[__key] # vendorred from peft.utils.config.PeftType since we don't have hard dependency on peft # see https://github.com/huggingface/peft/blob/main/src/peft/utils/config.py class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta): - PROMPT_TUNING = "PROMPT_TUNING" - P_TUNING = "P_TUNING" - PREFIX_TUNING = "PREFIX_TUNING" - LORA = "LORA" - ADALORA = "ADALORA" - ADAPTION_PROMPT = "ADAPTION_PROMPT" - IA3 = "IA3" - - @classmethod - def _missing_(cls, value: object) -> enum.Enum | None: - if isinstance(value, str): - normalized = inflection.underscore(value).upper() - if normalized in cls._member_map_: return cls._member_map_[normalized] - return None - @classmethod - def supported(cls) -> set[str]: return {inflection.underscore(v.value) for v in cls} - def to_str(self) -> str: return self.value - @staticmethod - def get(__key: str | t.Any, /) -> PeftType: return PeftType[__key] # type-safe getitem. - + PROMPT_TUNING = "PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" + LORA = "LORA" + ADALORA = "ADALORA" + ADAPTION_PROMPT = "ADAPTION_PROMPT" + IA3 = "IA3" + @classmethod + def _missing_(cls, value: object) -> enum.Enum | None: + if isinstance(value, str): + normalized = inflection.underscore(value).upper() + if normalized in cls._member_map_: return cls._member_map_[normalized] + return None + @classmethod + def supported(cls) -> set[str]: return {inflection.underscore(v.value) for v in cls} + def to_str(self) -> str: return self.value + @staticmethod + def get(__key: str | t.Any, /) -> PeftType: return PeftType[__key] # type-safe getitem. _PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"} if t.TYPE_CHECKING: - AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning"] + AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning"] else: - AdapterType = str + AdapterType = str _object_setattr = object.__setattr__ - def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType: - if value is None: - raise ValueError("'AdapterType' cannot be None.") - if isinstance(value, PeftType): - return value - if value not in PeftType.supported(): - raise ValueError(f"Given '{value}' is not a supported adapter type.") - return PeftType.get(value) - + if value is None: raise ValueError("'AdapterType' cannot be None.") + if isinstance(value, PeftType): return value + if value not in PeftType.supported(): raise ValueError(f"Given '{value}' is not a supported adapter type.") + return PeftType.get(value) @attr.define(slots=True) class FineTuneConfig: - """FineTuneConfig defines a default value for fine-tuning this any given LLM. + """FineTuneConfig defines a default value for fine-tuning this any given LLM. - For example: + For example: - ```python - class FalconConfig(openllm.LLMConfig): + ```python + class FalconConfig(openllm.LLMConfig): + __config__ = { + "fine_tune_strategies": ( + { + "adapter_type": "lora", + "r": 64, + "lora_alpha": 16, + "lora_dropout": 0.1, + "bias": "none", + "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], + }, + ), + } + ``` - __config__ = { - "fine_tune_strategies": ( - { - "adapter_type": "lora", - "r": 64, - "lora_alpha": 16, - "lora_dropout": 0.1, - "bias": "none", - "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - }, - ), - } - ``` + This is a lower level API that leverage `peft` as well as openllm.LLMConfig to create default + and customization + """ - This is a lower level API that leverage `peft` as well as openllm.LLMConfig to create default - and customization - """ + if t.TYPE_CHECKING and not MYPY: + # fmt: off + # The following type stubs makes __init__ aware of attrs internal type converter. + @overload + def __init__(self, adapter_type: AdapterType = ..., adapter_config: dict[str, t.Any] = ..., inference_mode: bool = ..., llm_config_class: type[LLMConfig] = ...) -> None: ... + @overload + def __init__(self, adapter_type: PeftType = ..., adapter_config: dict[str, t.Any] = ..., inference_mode: bool = ..., llm_config_class: type[LLMConfig] = ...) -> None: ... + # The below should be generated via attrs. Only here to conform with pyright strict checking. + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: ... + # fmt: on - if t.TYPE_CHECKING and not MYPY: - # The following type stubs makes __init__ aware of attrs - # internal type converter. - @overload - def __init__( - self, - adapter_type: AdapterType = ..., - adapter_config: dict[str, t.Any] = ..., - inference_mode: bool = ..., - llm_config_class: type[LLMConfig] = ..., - ) -> None: - ... + adapter_type: PeftType = dantic.Field("lora", description=f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, default to 'lora'", use_default_converter=False, converter=_adapter_converter) + adapter_config: t.Dict[str, t.Any] = dantic.Field(None, description="The configuration for the adapter. The content of the dict depends on the adapter type.", validator=attr.validators.optional(attr.validators.instance_of(dict)), converter=attr.converters.default_if_none(factory=dict), use_default_converter=False) + inference_mode: bool = dantic.Field(False, description="Whether to use this Adapter for inference", use_default_converter=False) + llm_config_class: type[LLMConfig] = dantic.Field(None, description="The reference class to openllm.LLMConfig", use_default_converter=False) - @overload - def __init__( - self, - adapter_type: PeftType = ..., - adapter_config: dict[str, t.Any] = ..., - inference_mode: bool = ..., - llm_config_class: type[LLMConfig] = ..., - ) -> None: - ... + @requires_dependencies("peft", extra="fine-tune") + def to_peft_config(self) -> peft.PeftConfig: + adapter_config = self.adapter_config.copy() + # no need for peft_type since it is internally managed by OpenLLM and PEFT + if "peft_type" in adapter_config: adapter_config.pop("peft_type") + # respect user set task_type if it is passed, otherwise use one managed by OpenLLM + task_type, inference_mode = adapter_config.pop("task_type", peft.TaskType[self.llm_config_class.peft_task_type()]), adapter_config.pop("inference_mode", self.inference_mode) + return peft.PEFT_TYPE_TO_CONFIG_MAPPING[self.adapter_type.to_str()](task_type=task_type, inference_mode=inference_mode, **adapter_config) - # The below should be generated via attrs. Only here to conform with pyright strict checking. - def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: - ... + def train(self) -> FineTuneConfig: _object_setattr(self, "inference_mode", False); return self + def eval(self) -> FineTuneConfig: _object_setattr(self, "inference_mode", True); return self - adapter_type: PeftType = dantic.Field( - "lora", - description=f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, default to 'lora'", - use_default_converter=False, - converter=_adapter_converter, - ) - adapter_config: t.Dict[str, t.Any] = dantic.Field( - None, - description="The configuration for the adapter. The content of the dict depends on the adapter type.", - validator=attr.validators.optional(attr.validators.instance_of(dict)), - converter=attr.converters.default_if_none(factory=dict), - use_default_converter=False, - ) - inference_mode: bool = dantic.Field( - False, - description="Whether to use this Adapter for inference", - use_default_converter=False, - ) - llm_config_class: type[LLMConfig] = dantic.Field( - None, - description="The reference class to openllm.LLMConfig", - use_default_converter=False, - ) + def with_config(self, **attrs: t.Any) -> FineTuneConfig: + adapter_type, inference_mode = attrs.pop("adapter_type", self.adapter_type), attrs.get("inference_mode", self.inference_mode) + if "llm_config_class" in attrs: raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.") + return attr.evolve(self, adapter_type=adapter_type, inference_mode=inference_mode, adapter_config=config_merger.merge(self.adapter_config, attrs)) - @requires_dependencies("peft", extra="fine-tune") - def to_peft_config(self) -> peft.PeftConfig: - # makes a copy to correctly set all modules - adapter_config = self.adapter_config.copy() - if "peft_type" in adapter_config: - # no need for peft_type since it is internally - # managed by OpenLLM and PEFT - adapter_config.pop("peft_type") + @classmethod + def make_adapter_config_class(cls, adapter_type: AdapterType, llm_config_class: type[LLMConfig], /, *, docs: str | None = None, **attrs: t.Any) -> type[FineTuneConfig]: + """A loose codegen to create default subclass for given adapter config type.""" + _new_default = {"adapter_type": PeftType[adapter_type], "adapter_config": attrs, "llm_config_class": llm_config_class} - # respect user set task_type if it is passed, otherwise use one managed by OpenLLM - task_type = adapter_config.pop("task_type", peft.TaskType[self.llm_config_class.peft_task_type()]) - inference_mode = adapter_config.pop("inference_mode", self.inference_mode) - - return peft.PEFT_TYPE_TO_CONFIG_MAPPING[self.adapter_type.to_str()]( - task_type=task_type, - inference_mode=inference_mode, - **adapter_config, - ) - - def train(self) -> FineTuneConfig: - _object_setattr(self, "inference_mode", False) - return self - - def eval(self) -> FineTuneConfig: - _object_setattr(self, "inference_mode", True) - return self - - def with_config(self, **attrs: t.Any) -> FineTuneConfig: - """Create a new instance of FineTuneConfig with the given attributes.""" - adapter_type = attrs.pop("adapter_type", self.adapter_type) - inference_mode = attrs.get("inference_mode", self.inference_mode) - if "llm_config_class" in attrs: - raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.") - return attr.evolve( - self, - adapter_type=adapter_type, - inference_mode=inference_mode, - adapter_config=config_merger.merge(self.adapter_config, attrs), - ) - - @classmethod - def make_adapter_config_class(cls, adapter_type: AdapterType, llm_config_class: type[LLMConfig], /, *, docs: str | None = None, **attrs: t.Any) -> type[FineTuneConfig]: - """A loose codegen to create default subclass for given adapter config type. - - This is used to make adapter subclass - """ - _new_default = {"adapter_type": PeftType[adapter_type], "adapter_config": attrs, "llm_config_class": llm_config_class} - def transformers(_: type[t.Any], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]: - transformed: list[attr.Attribute[t.Any]] = [] - for f in fields: - if f.name in _new_default: transformed.append(f.evolve(default=_new_default[f.name])) - else: transformed.append(f) - return transformed - - klass = attr.make_class( - f"{inflection.camelize(adapter_type)}{llm_config_class.__name__}", - [], - bases=(cls,), - slots=True, - weakref_slot=True, - frozen=True, - repr=True, - collect_by_mro=True, - field_transformer=transformers, - ) - if docs is not None: klass.__doc__ = docs - - return klass + def transformers(_: type[t.Any], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]: + transformed: list[attr.Attribute[t.Any]] = [] + for f in fields: + if f.name in _new_default: transformed.append(f.evolve(default=_new_default[f.name])) + else: transformed.append(f) + return transformed + klass = attr.make_class(f"{inflection.camelize(adapter_type)}{llm_config_class.__name__}", [], bases=(cls,), slots=True, weakref_slot=True, frozen=True, repr=True, collect_by_mro=True, field_transformer=transformers) + if docs is not None: klass.__doc__ = docs + return klass @attr.frozen(slots=True, repr=False, init=False) class GenerationConfig(ReprMixin): - """GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor. + """GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor. - Note that we always set `do_sample=True`. This class is not designed to be used directly, rather - to be used conjunction with LLMConfig. The instance of the generation config can then be accessed - via ``LLMConfig.generation_config``. - """ + Note that we always set `do_sample=True`. This class is not designed to be used directly, rather + to be used conjunction with LLMConfig. The instance of the generation config can then be accessed + via ``LLMConfig.generation_config``. + """ - # NOTE: parameters for controlling the length of the output - max_new_tokens: int = dantic.Field( - 20, - ge=0, - description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.", - ) - min_length: int = dantic.Field( - 0, - ge=0, - description="""The minimum length of the sequence to be generated. Corresponds to the length of the - input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.""", - ) - min_new_tokens: int = dantic.Field( - description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.", - ) - early_stopping: bool = dantic.Field( - False, - description="""Controls the stopping condition for beam-based methods, like beam-search. It accepts the + # NOTE: parameters for controlling the length of the output + max_new_tokens: int = dantic.Field(20, ge=0, description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.") + min_length: int = dantic.Field(0, ge=0, description="""The minimum length of the sequence to be generated. Corresponds to the length of the + input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.""") + min_new_tokens: int = dantic.Field(description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.") + early_stopping: bool = dantic.Field( + False, description="""Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) """, - ) - max_time: float = dantic.Field( - description="""The maximum amount of time you allow the computation to run for in seconds. generation will - still finish the current pass after allocated time has been passed.""", - ) + ) + max_time: float = dantic.Field(description="""The maximum amount of time you allow the computation to run for in seconds. generation will + still finish the current pass after allocated time has been passed.""") - # NOTE: Parameters for controling generaiton strategies - num_beams: int = dantic.Field(1, description="Number of beams for beam search. 1 means no beam search.") - num_beam_groups: int = dantic.Field( - 1, - description="""Number of groups to divide `num_beams` into in order to ensure diversity among different - groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.""", - ) - penalty_alpha: float = dantic.Field( - description="""The values balance the model confidence and the degeneration penalty in - contrastive search decoding.""", - ) - use_cache: bool = dantic.Field( - True, - description="""Whether or not the model should use the past last - key/values attentions (if applicable to the model) to speed up decoding.""", - ) + # NOTE: Parameters for controling generaiton strategies + num_beams: int = dantic.Field(1, description="Number of beams for beam search. 1 means no beam search.") + num_beam_groups: int = dantic.Field(1, description="""Number of groups to divide `num_beams` into in order to ensure diversity among different + groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.""") + penalty_alpha: float = dantic.Field(description="""The values balance the model confidence and the degeneration penalty in + contrastive search decoding.""") + use_cache: bool = dantic.Field(True, description="""Whether or not the model should use the past last + key/values attentions (if applicable to the model) to speed up decoding.""") - # NOTE: Parameters for manipulation of the model output logits - temperature: float = dantic.Field( - 1.0, ge=0.0, le=1.0, description="The value used to modulate the next token probabilities." - ) - top_k: int = dantic.Field( - 50, description="The number of highest probability vocabulary tokens to keep for top-k-filtering." - ) - top_p: float = dantic.Field( - 1.0, - description="""If set to float < 1, only the smallest set of most probable tokens with - probabilities that add up to `top_p` or higher are kept for generation.""", - ) - typical_p: float = dantic.Field( - 1.0, - description="""Local typicality measures how similar the conditional probability of predicting a target + # NOTE: Parameters for manipulation of the model output logits + temperature: float = dantic.Field(1.0, ge=0.0, le=1.0, description="The value used to modulate the next token probabilities.") + top_k: int = dantic.Field(50, description="The number of highest probability vocabulary tokens to keep for top-k-filtering.") + top_p: float = dantic.Field(1.0, description="""If set to float < 1, only the smallest set of most probable tokens with + probabilities that add up to `top_p` or higher are kept for generation.""") + typical_p: float = dantic.Field( + 1.0, description="""Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. """, - ) - epsilon_cutoff: float = dantic.Field( - 0.0, - description="""\ + ) + epsilon_cutoff: float = dantic.Field( + 0.0, description="""\ If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. """, - ) - eta_cutoff: float = dantic.Field( - 0.0, - description="""Eta sampling is a hybrid of locally typical sampling and epsilon sampling. + ) + eta_cutoff: float = dantic.Field( + 0.0, description="""Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. """, - ) - diversity_penalty: float = dantic.Field( - 0.0, - description="""This value is subtracted from a beam's score if it generates a token same - as any beam from other group at a particular time. Note that `diversity_penalty` is only - effective if `group beam search` is enabled. - """, - ) - repetition_penalty: float = dantic.Field( - 1.0, - description="""The parameter for repetition penalty. 1.0 means no penalty. - See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.""", - ) - encoder_repetition_penalty: float = dantic.Field( - 1.0, - description="""The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are - not in the original input. 1.0 means no penalty.""", - ) - length_penalty: float = dantic.Field( - 1.0, - description="""Exponential penalty to the length that is used with beam-based generation. It is applied + ) + diversity_penalty: float = dantic.Field(0.0, description="""This value is subtracted from a beam's score if it generates a token same + as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. + """,) + repetition_penalty: float = dantic.Field(1.0, description="""The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.""") + encoder_repetition_penalty: float = dantic.Field(1.0, description="""The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the original input. 1.0 means no penalty.""") + length_penalty: float = dantic.Field( + 1.0, description="""Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences. """, - ) - no_repeat_ngram_size: int = dantic.Field( - 0, description="If set to int > 0, all ngrams of that size can only occur once." - ) - bad_words_ids: t.List[t.List[int]] = dantic.Field( - description="""List of token ids that are not allowed to be generated. In order to get the token ids + ) + no_repeat_ngram_size: int = dantic.Field(0, description="If set to int > 0, all ngrams of that size can only occur once.") + bad_words_ids: t.List[t.List[int]] = dantic.Field(description="""List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`. - """, - ) + """,) - # NOTE: t.Union is not yet supported on CLI, but the environment variable should already be available. - force_words_ids: t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]] = dantic.Field( - description="""List of token ids that must be generated. If given a `List[List[int]]`, this is treated + # NOTE: t.Union is not yet supported on CLI, but the environment variable should already be available. + force_words_ids: t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]] = dantic.Field( + description="""List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. """, - ) - renormalize_logits: bool = dantic.Field( - False, - description="""Whether to renormalize the logits after applying all the logits processors or warpers + ) + renormalize_logits: bool = dantic.Field(False, description="""Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. - """, - ) - constraints: t.List[Constraint] = dantic.Field( - description="""Custom constraints that can be added to the generation to ensure that the output + """,) + constraints: t.List[Constraint] = dantic.Field(description="""Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by ``Constraint`` objects, in the most sensible way possible. - """, - ) - forced_bos_token_id: int = dantic.Field( - description="""The id of the token to force as the first generated token after the + """) + forced_bos_token_id: int = dantic.Field(description="""The id of the token to force as the first generated token after the ``decoder_start_token_id``. Useful for multilingual models like [mBART](https://huggingface.co/docs/transformers/model_doc/mbart) where the first generated token needs to be the target language token. - """, - ) - forced_eos_token_id: t.Union[int, t.List[int]] = dantic.Field( - description="""The id of the token to force as the last generated token when `max_length` is reached. - Optionally, use a list to set multiple *end-of-sequence* tokens.""", - ) - remove_invalid_values: bool = dantic.Field( - False, - description="""Whether to remove possible *nan* and *inf* outputs of the model to prevent the - generation method to crash. Note that using `remove_invalid_values` can slow down generation.""", - ) - exponential_decay_length_penalty: t.Tuple[int, float] = dantic.Field( - description="""This tuple adds an exponentially increasing length penalty, after a certain amount of tokens + """,) + forced_eos_token_id: t.Union[int, t.List[int]] = dantic.Field(description="""The id of the token to force as the last generated token when `max_length` is reached. + Optionally, use a list to set multiple *end-of-sequence* tokens.""") + remove_invalid_values: bool = dantic.Field(False, description="""Whether to remove possible *nan* and *inf* outputs of the model to prevent the + generation method to crash. Note that using `remove_invalid_values` can slow down generation.""") + exponential_decay_length_penalty: t.Tuple[ + int, float] = dantic.Field(description="""This tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay - """, - ) - suppress_tokens: t.List[int] = dantic.Field( - description="""A list of tokens that will be suppressed at generation. The `SupressTokens` logit + """,) + suppress_tokens: t.List[int] = dantic.Field(description="""A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled. - """, - ) - begin_suppress_tokens: t.List[int] = dantic.Field( - description="""A list of tokens that will be suppressed at the beginning of the generation. The + """) + begin_suppress_tokens: t.List[int] = dantic.Field(description="""A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. - """, - ) - forced_decoder_ids: t.List[t.List[int]] = dantic.Field( - description="""A list of pairs of integers which indicates a mapping from generation indices to token indices + """) + forced_decoder_ids: t.List[t.List[int]] = dantic.Field(description="""A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. - """, - ) + """,) - # NOTE: Parameters that define the output variables of `generate` - num_return_sequences: int = dantic.Field( - 1, description="The number of independently computed returned sequences for each element in the batch." - ) - output_attentions: bool = dantic.Field( - False, - description="""Whether or not to return the attentions tensors of all attention layers. - See `attentions` under returned tensors for more details. """, - ) - output_hidden_states: bool = dantic.Field( - False, - description="""Whether or not to return the hidden states of all layers. - See `hidden_states` under returned tensors for more details. - """, - ) - output_scores: bool = dantic.Field( - False, - description="""Whether or not to return the prediction scores. See `scores` under returned - tensors for more details.""", - ) + # NOTE: Parameters that define the output variables of `generate` + num_return_sequences: int = dantic.Field(1, description="The number of independently computed returned sequences for each element in the batch.") + output_attentions: bool = dantic.Field(False, description="""Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details.""") + output_hidden_states: bool = dantic.Field(False, description="""Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.""") + output_scores: bool = dantic.Field(False, description="""Whether or not to return the prediction scores. See `scores` under returned tensors for more details.""") - # NOTE: Special tokens that can be used at generation time - pad_token_id: int = dantic.Field(description="The id of the *padding* token.") - bos_token_id: int = dantic.Field(description="The id of the *beginning-of-sequence* token.") - eos_token_id: t.Union[int, t.List[int]] = dantic.Field( - description="""The id of the *end-of-sequence* token. Optionally, use a list to set - multiple *end-of-sequence* tokens.""", - ) + # NOTE: Special tokens that can be used at generation time + pad_token_id: int = dantic.Field(description="The id of the *padding* token.") + bos_token_id: int = dantic.Field(description="The id of the *beginning-of-sequence* token.") + eos_token_id: t.Union[int, t.List[int]] = dantic.Field(description="""The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.""") + # NOTE: Generation parameters exclusive to encoder-decoder models + encoder_no_repeat_ngram_size: int = dantic.Field(0, description="""If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. """) + decoder_start_token_id: int = dantic.Field(description="""If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. """) - # NOTE: Generation parameters exclusive to encoder-decoder models - encoder_no_repeat_ngram_size: int = dantic.Field( - 0, - description="""If set to int > 0, all ngrams of that size that occur in the - `encoder_input_ids` cannot occur in the `decoder_input_ids`. - """, - ) - decoder_start_token_id: int = dantic.Field( - description="""If an encoder-decoder model starts decoding with a - different token than *bos*, the id of that token. - """, - ) + if t.TYPE_CHECKING and not MYPY: + # stubs this for pyright as mypy already has a attr plugin builtin + def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: + ... - if t.TYPE_CHECKING and not MYPY: - # stubs this for pyright as mypy already has a attr plugin builtin - def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: ... - def __init__(self, *, _internal: bool = False, **attrs: t.Any): - if not _internal: raise RuntimeError("GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config") - self.__attrs_init__(**attrs) - def __getitem__(self, item: str) -> t.Any: - if hasattr(self, item): return getattr(self, item) - raise KeyError(f"'{self.__class__.__name__}' has no attribute {item}.") - @property - def __repr_keys__(self) -> set[str]: return {i.name for i in attr.fields(self.__class__)} + def __init__(self, *, _internal: bool = False, **attrs: t.Any): + if not _internal: raise RuntimeError("GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config") + self.__attrs_init__(**attrs) + def __getitem__(self, item: str) -> t.Any: + if hasattr(self, item): return getattr(self, item) + raise KeyError(f"'{self.__class__.__name__}' has no attribute {item}.") + + @property + def __repr_keys__(self) -> set[str]: + return {i.name for i in attr.fields(self.__class__)} bentoml_cattr.register_unstructure_hook_factory( lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig), - lambda cls: make_dict_unstructure_fn( - cls, - bentoml_cattr, - # The below is the default, put here for strict annotations - _cattrs_omit_if_default=False, - _cattrs_use_linecache=True, - **{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)}, - ), -) - + lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False, _cattrs_use_linecache=True, + **{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)})) @attr.frozen(slots=True, repr=False, init=False) class SamplingParams(ReprMixin): - """SamplingParams is the attr-compatible version of ``vllm.SamplingParams``. It provides some utilities to also respect shared variables from ``openllm.LLMConfig``. + """SamplingParams is the attr-compatible version of ``vllm.SamplingParams``. It provides some utilities to also respect shared variables from ``openllm.LLMConfig``. - The following value will be parsed directly from ``openllm.LLMConfig``: - - temperature - - top_k - - top_p - - max_tokens -> max_new_tokens - """ - - n: int = dantic.Field(1, description="Number of output sequences to return for the given prompt.") - best_of: int = dantic.Field( - None, - description="""\ + The following value will be parsed directly from ``openllm.LLMConfig``: + - temperature + - top_k + - top_p + - max_tokens -> max_new_tokens + """ + n: int = dantic.Field(1, description="Number of output sequences to return for the given prompt.") + best_of: int = dantic.Field(None, description="""\ Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. This is treated as the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`. - """, - ) - presence_penalty: float = dantic.Field( - 0.0, - description="""\ - Float that penalizes new tokens based on whether they + """,) + presence_penalty: float = dantic.Field(0.0, description="""Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens. - """, - ) - frequency_penalty: float = dantic.Field( - 0.0, - description="""\ + """,) + frequency_penalty: float = dantic.Field(0.0, description="""\ Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens. - """, - ) - use_beam_search: bool = dantic.Field(False, description="Whether to use beam search instead of sampling.") - stop: t.List[str] = dantic.Field( - None, - description="""\ - List of strings that stop the generation when they are generated. - The returned output will not contain the stop strings.""", - ) - ignore_eos: bool = dantic.Field( - False, - description="""\ - Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. - """, - ) - logprobs: int = dantic.Field( - None, - description="""\ - Number of log probabilities to return per output token.""", - ) + """,) + use_beam_search: bool = dantic.Field(False, description="Whether to use beam search instead of sampling.") + stop: t.List[str] = dantic.Field(None, description="List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.") + ignore_eos: bool = dantic.Field(False, description="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.") + logprobs: int = dantic.Field(None, description="Number of log probabilities to return per output token.") - if t.TYPE_CHECKING and not MYPY: - max_tokens: int - temperature: float - top_k: int - top_p: float + if t.TYPE_CHECKING and not MYPY: + max_tokens: int + temperature: float + top_k: int + top_p: float - def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: - ... + def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: + ... - def __init__(self, *, _internal: bool = False, **attrs: t.Any): - if not _internal: raise RuntimeError("SamplingParams is not meant to be used directly, but you can access this via a LLMConfig.sampling_config or create one with 'SamplingParams.from_generation_config'") - _object_setattr(self, "max_tokens", attrs.pop("max_tokens", 16)) - _object_setattr(self, "temperature", attrs.pop("temperature", 1.0)) - _object_setattr(self, "top_k", attrs.pop("top_k", -1)) - _object_setattr(self, "top_p", attrs.pop("top_p", 1.0)) - self.__attrs_init__(**attrs) - def __getitem__(self, item: str) -> t.Any: - if hasattr(self, item): return getattr(self, item) - raise KeyError(f"'{self.__class__.__name__}' has no attribute {item}.") - @property - def __repr_keys__(self) -> set[str]: return {i.name for i in attr.fields(self.__class__)} - @classmethod - def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> t.Self: - """The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``.""" - stop = attrs.pop("stop", None) - if stop is not None and isinstance(stop, str): stop = [stop] - attrs["stop"] = stop - - if "max_tokens" in attrs and "max_new_tokens" in attrs: raise ValueError("Both 'max_tokens' and 'max_new_tokens' are passed. Make sure to only use one of them.") - temperature = first_not_none(attrs.pop("temperature", None), default=generation_config["temperature"]) - top_k = first_not_none(attrs.pop("top_k", None), default=generation_config["top_k"]) - top_p = first_not_none(attrs.pop("top_p", None), default=generation_config["top_p"]) - max_tokens = first_not_none(attrs.pop("max_tokens", None), attrs.pop("max_new_tokens", None), default=generation_config["max_new_tokens"]) - - return cls(_internal=True, temperature=temperature, top_k=top_k, top_p=top_p, max_tokens=max_tokens, **attrs) - @requires_dependencies("vllm", extra="vllm") - def to_vllm(self) -> vllm.SamplingParams: return vllm.SamplingParams(max_tokens=self.max_tokens, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, **bentoml_cattr.unstructure(self)) + def __init__(self, *, _internal: bool = False, **attrs: t.Any): + if not _internal: raise RuntimeError("SamplingParams is not meant to be used directly, but you can access this via a LLMConfig.sampling_config or create one with 'SamplingParams.from_generation_config'") + _object_setattr(self, "max_tokens", attrs.pop("max_tokens", 16)) + _object_setattr(self, "temperature", attrs.pop("temperature", 1.0)) + _object_setattr(self, "top_k", attrs.pop("top_k", -1)) + _object_setattr(self, "top_p", attrs.pop("top_p", 1.0)) + self.__attrs_init__(**attrs) + def __getitem__(self, item: str) -> t.Any: + if hasattr(self, item): return getattr(self, item) + raise KeyError(f"'{self.__class__.__name__}' has no attribute {item}.") + @property + def __repr_keys__(self) -> set[str]: return {i.name for i in attr.fields(self.__class__)} + @classmethod + def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> t.Self: + """The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``.""" + stop = attrs.pop("stop", None) + if stop is not None and isinstance(stop, str): stop = [stop] + attrs["stop"] = stop + if "max_tokens" in attrs and "max_new_tokens" in attrs: raise ValueError("Both 'max_tokens' and 'max_new_tokens' are passed. Make sure to only use one of them.") + temperature = first_not_none(attrs.pop("temperature", None), default=generation_config["temperature"]) + top_k = first_not_none(attrs.pop("top_k", None), default=generation_config["top_k"]) + top_p = first_not_none(attrs.pop("top_p", None), default=generation_config["top_p"]) + max_tokens = first_not_none(attrs.pop("max_tokens", None), attrs.pop("max_new_tokens", None), default=generation_config["max_new_tokens"]) + return cls(_internal=True, temperature=temperature, top_k=top_k, top_p=top_p, max_tokens=max_tokens, **attrs) + @requires_dependencies("vllm", extra="vllm") + def to_vllm(self) -> vllm.SamplingParams: return vllm.SamplingParams(max_tokens=self.max_tokens, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, **bentoml_cattr.unstructure(self)) bentoml_cattr.register_unstructure_hook_factory( lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams), - lambda cls: make_dict_unstructure_fn( - cls, - bentoml_cattr, - # The below is the default, put here for strict annotations - _cattrs_omit_if_default=False, - _cattrs_use_linecache=True, - **{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)}, - ), -) -bentoml_cattr.register_structure_hook_factory( - lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams), - lambda cls: make_dict_structure_fn( - cls, bentoml_cattr, _cattrs_forbid_extra_keys=True, max_new_tokens=override(rename="max_tokens") - ), -) + lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False, _cattrs_use_linecache=True, + **{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)})) +bentoml_cattr.register_structure_hook_factory(lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams), lambda cls: make_dict_structure_fn(cls, bentoml_cattr, _cattrs_forbid_extra_keys=True, max_new_tokens=override(rename="max_tokens"))) # cached it here to save one lookup per assignment _object_getattribute = object.__getattribute__ - class ModelSettings(t.TypedDict, total=False): - """ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__. + """ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__. - Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig. + Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig. - If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__ - stubs for type-checking purposes. - """ + If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__ + stubs for type-checking purposes. + """ - # NOTE: These required fields should be at the top, as it will be kw_only - default_id: Required[str] - model_ids: Required[ListStr] - architecture: Required[str] + # NOTE: These required fields should be at the top, as it will be kw_only + default_id: Required[str] + model_ids: Required[ListStr] + architecture: Required[str] - # default OpenLLM runtime imlementation - default_implementation: NotRequired[t.Dict[LiteralResourceSpec, LiteralRuntime]] + # default OpenLLM runtime imlementation + default_implementation: NotRequired[t.Dict[LiteralResourceSpec, LiteralRuntime]] - # meta + # meta + url: str + requires_gpu: bool + trust_remote_code: bool + service_name: NotRequired[str] + requirements: t.Optional[ListStr] + + # llm implementation specifics + bettertransformer: bool + model_type: t.Literal["causal_lm", "seq2seq_lm"] + runtime: t.Literal["transformers", "ggml"] + + # naming convention, only name_type is needed to infer from the class + # as the three below it can be determined automatically + name_type: NotRequired[t.Optional[t.Literal["dasherize", "lowercase"]]] + model_name: NotRequired[str] + start_name: NotRequired[str] + env: NotRequired[openllm.utils.EnvVarMixin] + # serving configuration + timeout: int + workers_per_resource: t.Union[int, float] + + # the target generation_config class to be used. + fine_tune_strategies: t.Tuple[t.Dict[str, t.Any], ...] + + # tokenizer_class is the custom tokenizer class for this given LLM + tokenizer_class: t.Optional[str] + +_transformed_type: DictStrAny = {"fine_tune_strategies": t.Dict[AdapterType, FineTuneConfig], "default_implementation": t.Dict[LiteralResourceSpec, LiteralRuntime]} + +@attr.define( + frozen=False, slots=True, field_transformer=lambda _, __: [ + attr.Attribute.from_counting_attr(k, dantic.Field(kw_only=False if t.get_origin(ann) is not Required else True, auto_default=True, use_default_converter=False, + type=_transformed_type.get(k, ann), metadata={"target": f"__openllm_{k}__"}, description=f"ModelSettings field for {k}.")) for k, ann in t.get_type_hints(ModelSettings).items() + ]) +class _ModelSettingsAttr: + """Internal attrs representation of ModelSettings.""" + def __getitem__(self, key: str) -> t.Any: + if key in codegen.get_annotations(ModelSettings): + return _object_getattribute(self, key) + raise KeyError(key) + + @classmethod + def default(cls) -> _ModelSettingsAttr: + return cls(**t.cast(DictStrAny, ModelSettings(default_id="__default__", model_ids=["__default__"], architecture="PreTrainedModel", default_implementation={"cpu": "pt", "nvidia.com/gpu": "pt"}, + name_type="dasherize", requires_gpu=False, url="", model_type="causal_lm", trust_remote_code=False, requirements=None, tokenizer_class=None, timeout=int(36e6), + service_name="", workers_per_resource=1., runtime="transformers"))) + + # NOTE: The below are dynamically generated by the field_transformer + if t.TYPE_CHECKING: + # update-config-stubs.py: attrs start + default_id: str + model_ids: ListStr + architecture: str + default_implementation: t.Dict[LiteralResourceSpec, LiteralRuntime] url: str requires_gpu: bool trust_remote_code: bool - service_name: NotRequired[str] + service_name: str requirements: t.Optional[ListStr] - - # llm implementation specifics bettertransformer: bool model_type: t.Literal["causal_lm", "seq2seq_lm"] runtime: t.Literal["transformers", "ggml"] - - # naming convention, only name_type is needed to infer from the class - # as the three below it can be determined automatically - name_type: NotRequired[t.Optional[t.Literal["dasherize", "lowercase"]]] - model_name: NotRequired[str] - start_name: NotRequired[str] - env: NotRequired[openllm.utils.EnvVarMixin] - # serving configuration + name_type: t.Optional[t.Literal["dasherize", "lowercase"]] + model_name: str + start_name: str + env: openllm.utils.EnvVarMixin timeout: int workers_per_resource: t.Union[int, float] - - # the target generation_config class to be used. - fine_tune_strategies: t.Tuple[t.Dict[str, t.Any], ...] - - # tokenizer_class is the custom tokenizer class for this given LLM + fine_tune_strategies: t.Dict[AdapterType, FineTuneConfig] tokenizer_class: t.Optional[str] - - -_transformed_type: DictStrAny = { - "fine_tune_strategies": t.Dict[AdapterType, FineTuneConfig], - "default_implementation": t.Dict[LiteralResourceSpec, LiteralRuntime], -} - - -@attr.define( - slots=True, - field_transformer=lambda _, __: [ - attr.Attribute.from_counting_attr( - k, - dantic.Field( - kw_only=False if t.get_origin(ann) is not Required else True, - auto_default=True, - use_default_converter=False, - type=_transformed_type.get(k, ann), - metadata={"target": f"__openllm_{k}__"}, - description=f"ModelSettings field for {k}.", - ), - ) - for k, ann in t.get_type_hints(ModelSettings).items() - ], - frozen=False, -) -class _ModelSettingsAttr: - """Internal attrs representation of ModelSettings.""" - - def __getitem__(self, key: str) -> t.Any: - if key in codegen.get_annotations(ModelSettings): - return _object_getattribute(self, key) - raise KeyError(key) - - @classmethod - def default(cls) -> _ModelSettingsAttr: - return cls( - **t.cast( - DictStrAny, - ModelSettings( - default_id="__default__", - model_ids=["__default__"], - architecture="PreTrainedModel", - default_implementation={"cpu":"pt", "nvidia.com/gpu": "pt"}, - name_type="dasherize", - requires_gpu=False, - url="", - model_type="causal_lm", - trust_remote_code=False, - requirements=None, - tokenizer_class=None, - timeout=int(36e6), - service_name="", - workers_per_resource=1., - runtime="transformers", - ), - ) - ) - - # NOTE: The below are dynamically generated by the field_transformer - if t.TYPE_CHECKING: - # update-config-stubs.py: attrs start - default_id: str - model_ids: ListStr - architecture: str - default_implementation: t.Dict[LiteralResourceSpec, LiteralRuntime] - url: str - requires_gpu: bool - trust_remote_code: bool - service_name: str - requirements: t.Optional[ListStr] - bettertransformer: bool - model_type: t.Literal["causal_lm", "seq2seq_lm"] - runtime: t.Literal["transformers", "ggml"] - name_type: t.Optional[t.Literal["dasherize", "lowercase"]] - model_name: str - start_name: str - env: openllm.utils.EnvVarMixin - timeout: int - workers_per_resource: t.Union[int, float] - fine_tune_strategies: t.Dict[AdapterType, FineTuneConfig] - tokenizer_class: t.Optional[str] - # update-config-stubs.py: attrs stop + # update-config-stubs.py: attrs stop # a heuristic cascading implementation resolver based on available resources def get_default_implementation(default_implementation_mapping: dict[LiteralResourceSpec, LiteralRuntime]) -> LiteralRuntime: - available_spec = available_resource_spec() - if resource_spec("tpu") in available_spec: return default_implementation_mapping.get(resource_spec("tpu"), "pt") - elif resource_spec("amd") in available_spec: return default_implementation_mapping.get(resource_spec("amd"), "pt") - elif resource_spec("nvidia") in available_spec: return default_implementation_mapping.get(resource_spec("nvidia"), "pt") - else: return default_implementation_mapping.get(resource_spec("cpu"), "pt") + available_spec = available_resource_spec() + if resource_spec("tpu") in available_spec: return default_implementation_mapping.get(resource_spec("tpu"), "pt") + elif resource_spec("amd") in available_spec: return default_implementation_mapping.get(resource_spec("amd"), "pt") + elif resource_spec("nvidia") in available_spec: return default_implementation_mapping.get(resource_spec("nvidia"), "pt") + else: return default_implementation_mapping.get(resource_spec("cpu"), "pt") def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _ModelSettingsAttr: - if "generation_class" in cl_.__config__: raise ValueError(f"'generation_class' shouldn't be defined in '__config__', rather defining all required attributes under '{cl_}.GenerationConfig' instead.") + if "generation_class" in cl_.__config__: raise ValueError(f"'generation_class' shouldn't be defined in '__config__', rather defining all required attributes under '{cl_}.GenerationConfig' instead.") - required_fields = {k for k, ann in t.get_type_hints(ModelSettings).items() if t.get_origin(ann) is Required} - if any(i not in cl_.__config__ for i in required_fields): raise ValueError(f"Missing required fields {required_fields} '__config__'.") + required_fields = {k for k, ann in t.get_type_hints(ModelSettings).items() if t.get_origin(ann) is Required} + if any(i not in cl_.__config__ for i in required_fields): raise ValueError(f"Missing required fields {required_fields} '__config__'.") - _cl_name = cl_.__name__.replace("Config", "") - _settings_attr = cls.default() - has_custom_name = all(i in cl_.__config__ for i in {"model_name", "start_name"}) + _cl_name = cl_.__name__.replace("Config", "") + _settings_attr = cls.default() + has_custom_name = all(i in cl_.__config__ for i in {"model_name", "start_name"}) - _settings_attr = attr.evolve(_settings_attr, **cl_.__config__) - _final_value_dct: DictStrAny = {} + _settings_attr = attr.evolve(_settings_attr, **cl_.__config__) + _final_value_dct: DictStrAny = {} - if not has_custom_name: - _final_value_dct["model_name"] = inflection.underscore(_cl_name) if _settings_attr["name_type"] == "dasherize" else _cl_name.lower() - _final_value_dct["start_name"] = inflection.dasherize(_final_value_dct["model_name"]) if _settings_attr["name_type"] == "dasherize" else _final_value_dct["model_name"] + if not has_custom_name: + _final_value_dct["model_name"] = inflection.underscore(_cl_name) if _settings_attr["name_type"] == "dasherize" else _cl_name.lower() + _final_value_dct["start_name"] = inflection.dasherize(_final_value_dct["model_name"]) if _settings_attr["name_type"] == "dasherize" else _final_value_dct["model_name"] - model_name = _final_value_dct["model_name"] if "model_name" in _final_value_dct else _settings_attr.model_name - # if the default implementation dependencies doesn't exist, then always fallback to 'pt' - default_implementation = _settings_attr.default_implementation - for rs, runtime in default_implementation.items(): - library_stub = "torch" if runtime == "pt" else runtime - if not BACKENDS_MAPPING[library_stub][0](): default_implementation[rs] = "pt" - _final_value_dct["default_implementation"] = default_implementation + model_name = _final_value_dct["model_name"] if "model_name" in _final_value_dct else _settings_attr.model_name + # if the default implementation dependencies doesn't exist, then always fallback to 'pt' + default_implementation = _settings_attr.default_implementation + for rs, runtime in default_implementation.items(): + library_stub = "torch" if runtime == "pt" else runtime + if not BACKENDS_MAPPING[library_stub][0](): default_implementation[rs] = "pt" + _final_value_dct["default_implementation"] = default_implementation - env = openllm.utils.EnvVarMixin(model_name, get_default_implementation(default_implementation), model_id=_settings_attr.default_id,bettertransformer=_settings_attr.bettertransformer) - _final_value_dct["env"] = env + env = openllm.utils.EnvVarMixin(model_name, get_default_implementation(default_implementation), model_id=_settings_attr.default_id, bettertransformer=_settings_attr.bettertransformer) + _final_value_dct["env"] = env - # bettertransformer support - if _settings_attr["bettertransformer"] is None: _final_value_dct["bettertransformer"] = str(env.bettertransformer_value).upper() in ENV_VARS_TRUE_VALUES - # if requires_gpu is True, then disable BetterTransformer for quantization. - if _settings_attr["requires_gpu"]: _final_value_dct["bettertransformer"] = False - _final_value_dct["service_name"] = f"generated_{model_name}_service.py" + # bettertransformer support + if _settings_attr["bettertransformer"] is None: _final_value_dct["bettertransformer"] = str(env.bettertransformer_value).upper() in ENV_VARS_TRUE_VALUES + # if requires_gpu is True, then disable BetterTransformer for quantization. + if _settings_attr["requires_gpu"]: _final_value_dct["bettertransformer"] = False + _final_value_dct["service_name"] = f"generated_{model_name}_service.py" - # NOTE: The key for fine-tune strategies is 'fine_tune_strategies' - _fine_tune_strategies: tuple[dict[str, t.Any], ...] | None = getattr(_settings_attr, "fine_tune_strategies", None) - _converted: dict[AdapterType, FineTuneConfig] = {} - if _fine_tune_strategies is not None: - # the given value is a tuple[dict[str, t.Any] ,...] - for _possible_ft_config in _fine_tune_strategies: - _adapter_type: AdapterType | None = _possible_ft_config.pop("adapter_type", None) - if _adapter_type is None: raise RuntimeError("'adapter_type' is required under config definition (currently missing)'.") - _llm_config_class = _possible_ft_config.pop("llm_config_class", cl_) - _doc = _possible_ft_config.pop("docs", f"Default {inflection.camelize(_adapter_type)}Config for {model_name}") - _converted[_adapter_type] = codegen.add_method_dunders( - cl_, - FineTuneConfig.make_adapter_config_class(_adapter_type, _llm_config_class, docs=_doc, **_possible_ft_config), - )() - _final_value_dct["fine_tune_strategies"] = _converted - - return attr.evolve(_settings_attr, **_final_value_dct) + # NOTE: The key for fine-tune strategies is 'fine_tune_strategies' + _fine_tune_strategies: tuple[dict[str, t.Any], ...] | None = getattr(_settings_attr, "fine_tune_strategies", None) + _converted: dict[AdapterType, FineTuneConfig] = {} + if _fine_tune_strategies is not None: + # the given value is a tuple[dict[str, t.Any] ,...] + for _possible_ft_config in _fine_tune_strategies: + _adapter_type: AdapterType | None = _possible_ft_config.pop("adapter_type", None) + if _adapter_type is None: raise RuntimeError("'adapter_type' is required under config definition (currently missing)'.") + _llm_config_class = _possible_ft_config.pop("llm_config_class", cl_) + _doc = _possible_ft_config.pop("docs", f"Default {inflection.camelize(_adapter_type)}Config for {model_name}") + _converted[_adapter_type] = codegen.add_method_dunders(cl_, FineTuneConfig.make_adapter_config_class(_adapter_type, _llm_config_class, docs=_doc, **_possible_ft_config))() + _final_value_dct["fine_tune_strategies"] = _converted + return attr.evolve(_settings_attr, **_final_value_dct) bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings) - def _setattr_class(attr_name: str, value_var: t.Any) -> str: - """Use the builtin setattr to set *attr_name* to *value_var*. - - We can't use the cached object.__setattr__ since we are setting - attributes to a class. - - If add_dunder to True, the generated globs should include a __add_dunder - value that will be used to add the dunder methods to the class for given - value_var - """ - return f"setattr(cls, '{attr_name}', {value_var})" + """Use the builtin setattr to set *attr_name* to *value_var*. + We can't use the cached object.__setattr__ since we are setting attributes to a class. + If add_dunder to True, the generated globs should include a __add_dunder value that will be used to add the dunder methods to the class for given value_var + """ + return f"setattr(cls, '{attr_name}', {value_var})" def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm") -> t.Callable[..., None]: - """Generate the assignment script with prefix attributes __openllm___.""" - args: ListStr = [] - globs: DictStrAny = { - "cls": cls, - "_cached_attribute": attributes, - "_cached_getattribute_get": _object_getattribute.__get__, - } - annotations: DictStrAny = {"return": None} + """Generate the assignment script with prefix attributes __openllm___.""" + args: ListStr = [] + globs: DictStrAny = {"cls": cls, "_cached_attribute": attributes, "_cached_getattribute_get": _object_getattribute.__get__} + annotations: DictStrAny = {"return": None} - lines: ListStr = [] - for attr_name, field in attr.fields_dict(attributes.__class__).items(): - arg_name = field.metadata.get("target", f"__{_prefix}_{inflection.underscore(attr_name)}__") - args.append(f"{attr_name}=getattr(_cached_attribute, '{attr_name}')") - lines.append(_setattr_class(arg_name, attr_name)) - annotations[attr_name] = field.type - - return codegen.generate_function( - cls, "__assign_attr", lines, args=("cls", *args), globs=globs, annotations=annotations - ) + lines: ListStr = [] + for attr_name, field in attr.fields_dict(attributes.__class__).items(): + arg_name = field.metadata.get("target", f"__{_prefix}_{inflection.underscore(attr_name)}__") + args.append(f"{attr_name}=getattr(_cached_attribute, '{attr_name}')") + lines.append(_setattr_class(arg_name, attr_name)) + annotations[attr_name] = field.type + return codegen.generate_function(cls, "__assign_attr", lines, args=("cls", *args), globs=globs, annotations=annotations) _reserved_namespace = {"__config__", "GenerationConfig", "SamplingParams"} - @dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field)) def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]: - non_intrusive_setattr( - cls, - "__dataclass_transform__", - { - "order_default": True, - "kw_only_default": True, - "field_specifiers": (attr.field, dantic.Field), - }, - ) - return cls - + non_intrusive_setattr(cls, "__dataclass_transform__", {"order_default": True, "kw_only_default": True, "field_specifiers": (attr.field, dantic.Field)}) + return cls @attr.define(slots=True) class _ConfigAttr: - Field = dantic.Field - """Field is a alias to the internal dantic utilities to easily create + Field = dantic.Field + """Field is a alias to the internal dantic utilities to easily create attrs.fields with pydantic-compatible interface. For example: ```python class MyModelConfig(openllm.LLMConfig): - field1 = openllm.LLMConfig.Field(...) ``` """ - # NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING - if t.TYPE_CHECKING: - # NOTE: public attributes to override - __config__: ModelSettings = Field(None) - """Internal configuration for this LLM model. Each of the field in here will be populated + # NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING + if t.TYPE_CHECKING: + # NOTE: public attributes to override + __config__: ModelSettings = Field(None) + """Internal configuration for this LLM model. Each of the field in here will be populated and prefixed with __openllm___""" - GenerationConfig: object = Field(None) - """Users can override this subclass of any given LLMConfig to provide GenerationConfig + GenerationConfig: object = Field(None) + """Users can override this subclass of any given LLMConfig to provide GenerationConfig default value. For example: ```python @@ -1000,8 +705,8 @@ class _ConfigAttr: eos_token_id: int = 11 ``` """ - SamplingParams: object = Field(None) - """Users can override this subclass of any given LLMConfig to provide SamplingParams + SamplingParams: object = Field(None) + """Users can override this subclass of any given LLMConfig to provide SamplingParams default value. For example: ```python @@ -1013,44 +718,43 @@ class _ConfigAttr: eos_token_id: int = 11 ``` """ - # NOTE: Internal attributes that should only be used by OpenLLM. Users usually shouldn't - # concern any of these. These are here for pyright not to complain. - __attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = Field(None, init=False) - """Since we are writing our own __init_subclass__, which is an alternative way for __prepare__, + # NOTE: Internal attributes that should only be used by OpenLLM. Users usually shouldn't + # concern any of these. These are here for pyright not to complain. + __attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = Field(None, init=False) + """Since we are writing our own __init_subclass__, which is an alternative way for __prepare__, we want openllm.LLMConfig to be attrs-like dataclass that has pydantic-like interface. __attrs_attrs__ will be handled dynamically by __init_subclass__. """ - __openllm_hints__: DictStrAny = Field(None, init=False) - """An internal cache of resolved types for this LLMConfig.""" - __openllm_accepted_keys__: set[str] = Field(None, init=False) - """The accepted keys for this LLMConfig.""" - __openllm_extras__: DictStrAny = Field(None, init=False) - """Extra metadata for this LLMConfig.""" - __openllm_generation_class__: type[openllm._configuration.GenerationConfig] = Field(None) - """The result generated GenerationConfig class for this LLMConfig. This will be used + __openllm_hints__: DictStrAny = Field(None, init=False) + """An internal cache of resolved types for this LLMConfig.""" + __openllm_accepted_keys__: set[str] = Field(None, init=False) + """The accepted keys for this LLMConfig.""" + __openllm_extras__: DictStrAny = Field(None, init=False) + """Extra metadata for this LLMConfig.""" + __openllm_generation_class__: type[openllm._configuration.GenerationConfig] = Field(None) + """The result generated GenerationConfig class for this LLMConfig. This will be used to create the generation_config argument that can be used throughout the lifecycle. This class will also be managed internally by OpenLLM.""" - __openllm_sampling_class__: type[openllm._configuration.SamplingParams] = Field(None) - """The result generated SamplingParams class for this LLMConfig. This will be used + __openllm_sampling_class__: type[openllm._configuration.SamplingParams] = Field(None) + """The result generated SamplingParams class for this LLMConfig. This will be used to create arguments for vLLM LLMEngine that can be used throughout the lifecycle. This class will also be managed internally by OpenLLM.""" + def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: + """Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.""" - def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: - """Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.""" + # NOTE: The following will be populated from __config__ and also + # considered to be public API. Users can also access these via self[key] + # To update the docstring for these field, update it through tools/update-config-stubs.py - # NOTE: The following will be populated from __config__ and also - # considered to be public API. Users can also access these via self[key] - # To update the docstring for these field, update it through tools/update-config-stubs.py - - # update-config-stubs.py: special start - __openllm_default_id__: str = Field(None) - """Return the default model to use when using 'openllm start '. + # update-config-stubs.py: special start + __openllm_default_id__: str = Field(None) + """Return the default model to use when using 'openllm start '. This could be one of the keys in 'self.model_ids' or custom users model. This field is required when defining under '__config__'. """ - __openllm_model_ids__: ListStr = Field(None) - """A list of supported pretrained models tag for this given runnable. + __openllm_model_ids__: ListStr = Field(None) + """A list of supported pretrained models tag for this given runnable. For example: For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base", @@ -1058,8 +762,8 @@ class _ConfigAttr: This field is required when defining under '__config__'. """ - __openllm_architecture__: str = Field(None) - """The model architecture that is supported by this LLM. + __openllm_architecture__: str = Field(None) + """The model architecture that is supported by this LLM. Note that any model weights within this architecture generation can always be run and supported by this LLM. @@ -1069,44 +773,44 @@ class _ConfigAttr: ```bash openllm start gpt-neox --model-id stabilityai/stablelm-tuned-alpha-3b ```""" - __openllm_default_implementation__: t.Dict[LiteralResourceSpec, LiteralRuntime] = Field(None) - """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. + __openllm_default_implementation__: t.Dict[LiteralResourceSpec, LiteralRuntime] = Field(None) + """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. - It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm') + It is a dictionary of key as the accelerator spec in k4s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm') """ - __openllm_url__: str = Field(None) - """The resolved url for this LLMConfig.""" - __openllm_requires_gpu__: bool = Field(None) - """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""" - __openllm_trust_remote_code__: bool = Field(None) - """Whether to always trust remote code""" - __openllm_service_name__: str = Field(None) - """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""" - __openllm_requirements__: t.Optional[ListStr] = Field(None) - """The default PyPI requirements needed to run this given LLM. By default, we will depend on + __openllm_url__: str = Field(None) + """The resolved url for this LLMConfig.""" + __openllm_requires_gpu__: bool = Field(None) + """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""" + __openllm_trust_remote_code__: bool = Field(None) + """Whether to always trust remote code""" + __openllm_service_name__: str = Field(None) + """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""" + __openllm_requirements__: t.Optional[ListStr] = Field(None) + """The default PyPI requirements needed to run this given LLM. By default, we will depend on bentoml, torch, transformers.""" - __openllm_bettertransformer__: bool = Field(None) - """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""" - __openllm_model_type__: t.Literal["causal_lm", "seq2seq_lm"] = Field(None) - """The model type for this given LLM. By default, it should be causal language modeling. + __openllm_bettertransformer__: bool = Field(None) + """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""" + __openllm_model_type__: t.Literal["causal_lm", "seq2seq_lm"] = Field(None) + """The model type for this given LLM. By default, it should be causal language modeling. Currently supported 'causal_lm' or 'seq2seq_lm' """ - __openllm_runtime__: t.Literal["transformers", "ggml"] = Field(None) - """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""" - __openllm_name_type__: t.Optional[t.Literal["dasherize", "lowercase"]] = Field(None) - """The default name typed for this model. "dasherize" will convert the name to lowercase and + __openllm_runtime__: t.Literal["transformers", "ggml"] = Field(None) + """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""" + __openllm_name_type__: t.Optional[t.Literal["dasherize", "lowercase"]] = Field(None) + """The default name typed for this model. "dasherize" will convert the name to lowercase and replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both `model_name` and `start_name` must be specified.""" - __openllm_model_name__: str = Field(None) - """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""" - __openllm_start_name__: str = Field(None) - """Default name to be used with `openllm start`""" - __openllm_env__: openllm.utils.EnvVarMixin = Field(None) - """A EnvVarMixin instance for this LLMConfig.""" - __openllm_timeout__: int = Field(None) - """The default timeout to be set for this given LLM.""" - __openllm_workers_per_resource__: t.Union[int, float] = Field(None) - """The number of workers per resource. This is used to determine the number of workers to use for this model. + __openllm_model_name__: str = Field(None) + """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""" + __openllm_start_name__: str = Field(None) + """Default name to be used with `openllm start`""" + __openllm_env__: openllm.utils.EnvVarMixin = Field(None) + """A EnvVarMixin instance for this LLMConfig.""" + __openllm_timeout__: int = Field(None) + """The default timeout to be set for this given LLM.""" + __openllm_workers_per_resource__: t.Union[int, float] = Field(None) + """The number of workers per resource. This is used to determine the number of workers to use for this model. For example, if this is set to 0.5, then OpenLLM will use 1 worker per 2 resources. If this is set to 1, then OpenLLM will use 1 worker per resource. If this is set to 2, then OpenLLM will use 2 workers per resource. @@ -1115,794 +819,671 @@ class _ConfigAttr: By default, it is set to 1. """ - __openllm_fine_tune_strategies__: t.Dict[AdapterType, FineTuneConfig] = Field(None) - """The fine-tune strategies for this given LLM.""" - __openllm_tokenizer_class__: t.Optional[str] = Field(None) - """Optional tokenizer class for this given LLM. See Llama for example.""" - # update-config-stubs.py: special stop - + __openllm_fine_tune_strategies__: t.Dict[AdapterType, FineTuneConfig] = Field(None) + """The fine-tune strategies for this given LLM.""" + __openllm_tokenizer_class__: t.Optional[str] = Field(None) + """Optional tokenizer class for this given LLM. See Llama for example.""" + # update-config-stubs.py: special stop class _ConfigBuilder: - """A modified version of attrs internal _ClassBuilder, and should only be called within __init_subclass__ of LLMConfig. + """A modified version of attrs internal _ClassBuilder, and should only be called within __init_subclass__ of LLMConfig. - Where: - - has_custom_setattr=True - - getstate_setstate=None (config class will always be a slotted class.) - - slots=True - - auto_attribs=False (We should handle it before _ConfigBuilder is invoked) - - cache_hash=False (We don't need to cache the hash code of this object for now.) - - collect_by_mro=True (The correct behaviour to resolve inheritance) - - field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable) + Where: + - has_custom_setattr=True + - getstate_setstate=None (config class will always be a slotted class.) + - slots=True + - auto_attribs=False (We should handle it before _ConfigBuilder is invoked) + - cache_hash=False (We don't need to cache the hash code of this object for now.) + - collect_by_mro=True (The correct behaviour to resolve inheritance) + - field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable) - It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__ + It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__ + """ + + __slots__ = ("_cls", "_cls_dict", "_attr_names", "_attrs", "_model_name", "_base_attr_map", "_base_names", "_has_pre_init", "_has_post_init") + + def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr[t.Any]], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True): + attrs, base_attrs, base_attr_map = _transform_attrs(cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__)) + self._cls, self._model_name, self._cls_dict, self._attrs, self._base_names, self._base_attr_map = cls, cls.__openllm_model_name__, dict(cls.__dict__), attrs, {a.name for a in base_attrs}, base_attr_map + self._attr_names = tuple(a.name for a in attrs) + self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) + self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) + self._cls_dict["__attrs_attrs__"] = self._attrs + + def build_class(self) -> type[LLMConfig]: + """Finalize class based on the accumulated configuration. + + Builder cannot be used after calling this method. + + > A difference between this and attrs._ClassBuilder is that we don't + > create a new class after constructing all __dict__. This has to do + > with recursive called within __init_subclass__ """ + cd = {k: v for k, v in self._cls_dict.items() if k not in (*tuple(self._attr_names), "__dict__", "__weakref__")} + # Traverse the MRO to collect existing slots + # and check for an existing __weakref__. + weakref_inherited = False + existing_slots: DictStrAny = {} + for base_cls in self._cls.__mro__[1:-1]: + if base_cls.__dict__.get("__weakref__", None) is not None: weakref_inherited = True + existing_slots.update({name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])}) - __slots__ = ("_cls", "_cls_dict", "_attr_names", "_attrs", "_model_name", "_base_attr_map", "_base_names", "_has_pre_init", "_has_post_init") - def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr[t.Any]], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True): - attrs, base_attrs, base_attr_map = _transform_attrs(cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__)) - self._cls = cls - self._model_name = cls.__openllm_model_name__ - self._cls_dict = dict(cls.__dict__) - self._attrs = attrs - self._base_names = {a.name for a in base_attrs} - self._base_attr_map = base_attr_map - self._attr_names = tuple(a.name for a in attrs) - self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) - self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) - self._cls_dict["__attrs_attrs__"] = self._attrs - def build_class(self) -> type[LLMConfig]: - """Finalize class based on the accumulated configuration. + names = self._attr_names + base_names = set(self._base_names) + if "__weakref__" not in getattr(self._cls, "__slots__", ()) and "__weakref__" not in names and not weakref_inherited: names += ("__weakref__",) - Builder cannot be used after calling this method. + # We only add the names of attributes that aren't inherited. + # Setting __slots__ to inherited attributes wastes memory. + slot_names = [name for name in names if name not in base_names] + # There are slots for attributes from current class + # that are defined in parent classes. + # As their descriptors may be overridden by a child class, + # we collect them here and update the class dict + reused_slots = {slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names} + # We only add the names of attributes that aren't inherited. + # Setting __slots__ to inherited attributes wastes memory. + # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots + slot_names = [name for name in slot_names if name not in reused_slots] + cd.update(reused_slots) + cd["__slots__"] = tuple(slot_names) + cd["__qualname__"] = self._cls.__qualname__ - > A difference between this and attrs._ClassBuilder is that we don't - > create a new class after constructing all __dict__. This has to do - > with recursive called within __init_subclass__ - """ - cd = {k: v for k, v in self._cls_dict.items() if k not in (*tuple(self._attr_names), "__dict__", "__weakref__")} - # Traverse the MRO to collect existing slots - # and check for an existing __weakref__. - weakref_inherited = False - existing_slots: DictStrAny = {} - for base_cls in self._cls.__mro__[1:-1]: - if base_cls.__dict__.get("__weakref__", None) is not None: weakref_inherited = True - existing_slots.update({name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])}) + # We can only patch the class here, rather than instantiate + # a new one, since type.__new__ actually will invoke __init_subclass__ + # and since we use the _ConfigBuilder in __init_subclass__, it will + # raise recusion error. See https://peps.python.org/pep-0487/ for more + # information on how __init_subclass__ works. + for k, value in cd.items(): setattr(self._cls, k, value) - names = self._attr_names - base_names = set(self._base_names) - if "__weakref__" not in getattr(self._cls, "__slots__", ()) and "__weakref__" not in names and not weakref_inherited: names += ("__weakref__",) + return self.make_closure(self._cls) - # We only add the names of attributes that aren't inherited. - # Setting __slots__ to inherited attributes wastes memory. - slot_names = [name for name in names if name not in base_names] - # There are slots for attributes from current class - # that are defined in parent classes. - # As their descriptors may be overridden by a child class, - # we collect them here and update the class dict - reused_slots = { - slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names - } - # We only add the names of attributes that aren't inherited. - # Setting __slots__ to inherited attributes wastes memory. - # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots - slot_names = [name for name in slot_names if name not in reused_slots] - cd.update(reused_slots) - cd["__slots__"] = tuple(slot_names) + def make_closure(self, cls: type[t.Any]) -> type[t.Any]: + # The following is a fix for + # . + # If a method mentions `__class__` or uses the no-arg super(), the + # compiler will bake a reference to the class in the method itself + # as `method.__closure__`. Since we replace the class with a + # clone, we rewrite these references so it keeps working. + for item in cls.__dict__.values(): + # Class- and staticmethods hide their functions inside. + # These might need to be rewritten as well. + if isinstance(item, (classmethod, staticmethod)): closure_cells = getattr(item.__func__, "__closure__", None) + # Workaround for property `super()` shortcut (PY3-only). + # There is no universal way for other descriptors. + elif isinstance(item, property): closure_cells = getattr(item.fget, "__closure__", None) + else: closure_cells = getattr(item, "__closure__", None) - cd["__qualname__"] = self._cls.__qualname__ + if not closure_cells: # Catch None or the empty list. + continue + for cell in closure_cells: + try: match = cell.cell_contents is self._cls + except ValueError: pass # ValueError: Cell is empty + else: + if match: set_closure_cell(cell, cls) - # We can only patch the class here, rather than instantiate - # a new one, since type.__new__ actually will invoke __init_subclass__ - # and since we use the _ConfigBuilder in __init_subclass__, it will - # raise recusion error. See https://peps.python.org/pep-0487/ for more - # information on how __init_subclass__ works. - for k, value in cd.items(): - setattr(self._cls, k, value) - - return self.make_closure(self._cls) - - def make_closure(self, cls: type[t.Any]) -> type[t.Any]: - # The following is a fix for - # . - # If a method mentions `__class__` or uses the no-arg super(), the - # compiler will bake a reference to the class in the method itself - # as `method.__closure__`. Since we replace the class with a - # clone, we rewrite these references so it keeps working. - for item in cls.__dict__.values(): - if isinstance(item, (classmethod, staticmethod)): - # Class- and staticmethods hide their functions inside. - # These might need to be rewritten as well. - closure_cells = getattr(item.__func__, "__closure__", None) - elif isinstance(item, property): - # Workaround for property `super()` shortcut (PY3-only). - # There is no universal way for other descriptors. - closure_cells = getattr(item.fget, "__closure__", None) - else: - closure_cells = getattr(item, "__closure__", None) - - if not closure_cells: # Catch None or the empty list. - continue - for cell in closure_cells: - try: - match = cell.cell_contents is self._cls - except ValueError: # ValueError: Cell is empty - pass - else: - if match: - set_closure_cell(cell, cls) - - return llm_config_transform(cls) - - def add_attrs_init(self) -> t.Self: - self._cls_dict["__attrs_init__"] = codegen.add_method_dunders( - self._cls, - _make_init( - self._cls, - self._attrs, - self._has_pre_init, - self._has_post_init, - False, # frozen - True, # slots - False, # cache_hash - self._base_attr_map, - False, # This is not an exception - None, # no on_setattr - True, - ), - ) - return self - - def add_repr(self) -> t.Self: - for key, fn in ReprMixin.__dict__.items(): - if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): - self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) - self._cls_dict["__repr_keys__"] = property( - lambda _: {i.name for i in self._attrs} | {"generation_config", "sampling_config"} - ) - return self + return llm_config_transform(cls) + def add_attrs_init(self) -> t.Self: + self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(self._cls, _make_init(self._cls, self._attrs, self._has_pre_init, self._has_post_init, False, True, False, self._base_attr_map, False, None, True)) + return self + def add_repr(self) -> t.Self: + for key, fn in ReprMixin.__dict__.items(): + if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) + self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config", "sampling_config"}) + return self @attr.define(slots=True, init=False) class LLMConfig(_ConfigAttr): - """``openllm.LLMConfig`` is a pydantic-like ``attrs`` interface that offers fast and easy-to-use APIs. + """``openllm.LLMConfig`` is a pydantic-like ``attrs`` interface that offers fast and easy-to-use APIs. - It lives in between the nice UX of `pydantic` and fast performance of `attrs` where it allows users to quickly formulate - a LLMConfig for any LLM without worrying too much about performance. It does a few things: + It lives in between the nice UX of `pydantic` and fast performance of `attrs` where it allows users to quickly formulate + a LLMConfig for any LLM without worrying too much about performance. It does a few things: - - Automatic environment conversion: Each fields will automatically be provisioned with an environment - variable, make it easy to work with ahead-of-time or during serving time - - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, - i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options`` - - Automatic CLI generation: It can identify each fields and convert it to compatible Click options. - This means developers can use any of the LLMConfig to create CLI with compatible-Python - CLI library (click, typer, ...) + - Automatic environment conversion: Each fields will automatically be provisioned with an environment + variable, make it easy to work with ahead-of-time or during serving time + - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, + i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options`` + - Automatic CLI generation: It can identify each fields and convert it to compatible Click options. + This means developers can use any of the LLMConfig to create CLI with compatible-Python + CLI library (click, typer, ...) - > Internally, LLMConfig is an attrs class. All subclass of LLMConfig contains "attrs-like" features, - > which means LLMConfig will actually generate subclass to have attrs-compatible API, so that the subclass - > can be written as any normal Python class. + > Internally, LLMConfig is an attrs class. All subclass of LLMConfig contains "attrs-like" features, + > which means LLMConfig will actually generate subclass to have attrs-compatible API, so that the subclass + > can be written as any normal Python class. - To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass: + To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass: - ```python - class FlanT5Config(openllm.LLMConfig): - class GenerationConfig: - temperature: float = 0.75 - max_new_tokens: int = 3000 - top_k: int = 50 - top_p: float = 0.4 - repetition_penalty = 1.0 - ``` - By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted - to ``transformers.GenerationConfig``. These attribute can be accessed via ``LLMConfig.generation_config``. + ```python + class FlanT5Config(openllm.LLMConfig): + class GenerationConfig: + temperature: float = 0.75 + max_new_tokens: int = 3000 + top_k: int = 50 + top_p: float = 0.4 + repetition_penalty = 1.0 + ``` + By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted + to ``transformers.GenerationConfig``. These attribute can be accessed via ``LLMConfig.generation_config``. - By default, all LLMConfig must provide a __config__ with 'default_id' and 'model_ids'. + By default, all LLMConfig must provide a __config__ with 'default_id' and 'model_ids'. - All other fields are optional, and will be use default value if not set. + All other fields are optional, and will be use default value if not set. - ```python - class FalconConfig(openllm.LLMConfig): - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "requires_gpu": True, - "timeout": 3600000, - "url": "https://falconllm.tii.ae/", - "requirements": ["einops", "xformers", "safetensors"], - # NOTE: The below are always required - "default_id": "tiiuae/falcon-7b", - "model_ids": [ - "tiiuae/falcon-7b", - "tiiuae/falcon-40b", - "tiiuae/falcon-7b-instruct", - "tiiuae/falcon-40b-instruct", - ], - } - ``` + ```python + class FalconConfig(openllm.LLMConfig): + __config__ = { + "name_type": "lowercase", + "trust_remote_code": True, + "requires_gpu": True, + "timeout": 3600000, + "url": "https://falconllm.tii.ae/", + "requirements": ["einops", "xformers", "safetensors"], + # NOTE: The below are always required + "default_id": "tiiuae/falcon-7b", + "model_ids": [ + "tiiuae/falcon-7b", + "tiiuae/falcon-40b", + "tiiuae/falcon-7b-instruct", + "tiiuae/falcon-40b-instruct", + ], + } + ``` - > **Changelog**: - > Since 0.1.7, one can also define given fine-tune strategies for given LLM via its config: - ```python - class OPTConfig(openllm.LLMConfig): - __config__ = { - "name_type": "lowercase", - "trust_remote_code": False, - "url": "https://huggingface.co/docs/transformers/model_doc/opt", - "default_id": "facebook/opt-1.3b", - "model_ids": [ - "facebook/opt-125m", - "facebook/opt-350m", - "facebook/opt-1.3b", - "facebook/opt-2.7b", - "facebook/opt-6.7b", - "facebook/opt-66b", - ], - "fine_tune_strategies": ( - { - "adapter_type": "lora", - "r": 16, - "lora_alpha": 32, - "target_modules": ["q_proj", "v_proj"], - "lora_dropout": 0.05, - "bias": "none", - }, - ), - } - ``` + > **Changelog**: + > Since 0.1.7, one can also define given fine-tune strategies for given LLM via its config: + ```python + class OPTConfig(openllm.LLMConfig): + __config__ = { + "name_type": "lowercase", + "trust_remote_code": False, + "url": "https://huggingface.co/docs/transformers/model_doc/opt", + "default_id": "facebook/opt-1.3b", + "model_ids": [ + "facebook/opt-125m", + "facebook/opt-350m", + "facebook/opt-1.3b", + "facebook/opt-2.7b", + "facebook/opt-6.7b", + "facebook/opt-66b", + ], + "fine_tune_strategies": ( + { + "adapter_type": "lora", + "r": 16, + "lora_alpha": 32, + "target_modules": ["q_proj", "v_proj"], + "lora_dropout": 0.05, + "bias": "none", + }, + ), + } + ``` - Future work: - - Support pydantic-core as validation backend. + Future work: + - Support pydantic-core as validation backend. + """ + @classmethod + def _make_subclass(cls, class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: t.LiteralString | None = None) -> type[At]: + camel_name = cls.__name__.replace("Config", "") + klass = attr.make_class( + f"{camel_name}{class_attr}", [], bases=(base,), slots=True, weakref_slot=True, frozen=True, repr=False, init=False, collect_by_mro=True, + field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__, suffix=suffix_env, globs=globs, + default_callback=lambda field_name, field_default: getattr(getattr(cls, class_attr), field_name, field_default) if codegen.has_own_attribute(cls, class_attr) else field_default)) + # For pickling to work, the __module__ variable needs to be set to the + # frame where the class is created. This respect the module that is created from cls + try: klass.__module__ = cls.__module__ + except (AttributeError, ValueError): pass + return t.cast("type[At]", klass) + + def __init_subclass__(cls: type[LLMConfig]): + """The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract. + + This means we will construct all fields and metadata and hack into + how attrs use some of the 'magic' construction to generate the fields. + + It also does a few more extra features: It also generate all __openllm_*__ config from + ModelSettings (derived from __config__) to the class. """ + if not cls.__name__.endswith("Config"): + logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__) + cls.__name__ = f"{cls.__name__}Config" - @classmethod - def _make_subclass( - cls, - class_attr: str, - base: type[At], - globs: dict[str, t.Any] | None = None, - suffix_env: t.LiteralString | None = None, - ) -> type[At]: - camel_name = cls.__name__.replace("Config", "") + if not hasattr(cls, "__config__"): raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.") - klass = attr.make_class( - f"{camel_name}{class_attr}", - [], - bases=(base,), - slots=True, - weakref_slot=True, - frozen=True, - repr=False, - init=False, - collect_by_mro=True, - field_transformer=codegen.make_env_transformer( - cls, - cls.__openllm_model_name__, - suffix=suffix_env, - default_callback=lambda field_name, field_default: getattr( - getattr(cls, class_attr), field_name, field_default - ) - if codegen.has_own_attribute(cls, class_attr) - else field_default, - globs=globs, - ), - ) + # auto assignment attributes generated from __config__ after create the new slot class. + _make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls) + cls.__openllm_generation_class__ = cls._make_subclass("GenerationConfig", openllm._configuration.GenerationConfig, suffix_env="generation") + cls.__openllm_sampling_class__ = cls._make_subclass("SamplingParams", openllm._configuration.SamplingParams, suffix_env="sampling") - # For pickling to work, the __module__ variable needs to be set to the - # frame where the class is created. This respect the module that is created from cls - try: klass.__module__ = cls.__module__ - except (AttributeError, ValueError): pass - return t.cast("type[At]", klass) + # process a fields under cls.__dict__ and auto convert them with dantic.Field + # this is similar logic to attr._make._transform_attrs + cd = cls.__dict__ + anns = codegen.get_annotations(cls) + # _CountingAttr is the underlying representation of attr.field + ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)} + these: dict[str, _CountingAttr[t.Any]] = {} + annotated_names: set[str] = set() + for attr_name, typ in anns.items(): + if codegen.is_class_var(typ): continue + annotated_names.add(attr_name) + val = cd.get(attr_name, attr.NOTHING) + if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): + if val is attr.NOTHING: val = cls.Field(env=field_env_key(cls.__openllm_model_name__, attr_name)) + else: val = cls.Field(default=val, env=field_env_key(cls.__openllm_model_name__, attr_name)) + these[attr_name] = val + unannotated = ca_names - annotated_names + if len(unannotated) > 0: + missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter) + raise openllm.exceptions.MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}") + # We need to set the accepted key before generation_config + # as generation_config is a special field that users shouldn't pass. + cls.__openllm_accepted_keys__ = set(these.keys()) | {a.name for a in attr.fields(cls.__openllm_generation_class__)} | {a.name for a in attr.fields(cls.__openllm_sampling_class__)} + cls = _ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class() - def __init_subclass__(cls: type[LLMConfig]): - """The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract. + # Finally, resolve the types + if getattr(cls, "__attrs_types_resolved__", None) != cls: + # NOTE: We will try to resolve type here, and cached it for faster use + globs: DictStrAny = {"t": t, "typing": t, "Constraint": Constraint} + if cls.__module__ in sys.modules: globs.update(sys.modules[cls.__module__].__dict__) + attr.resolve_types(cls.__openllm_generation_class__, globalns=globs) + attr.resolve_types(cls.__openllm_sampling_class__, globalns=globs) + cls = attr.resolve_types(cls, globalns=globs) + # the hint cache for easier access + cls.__openllm_hints__ = {f.name: f.type for ite in [attr.fields(cls), attr.fields(cls.__openllm_generation_class__), attr.fields(cls.__openllm_sampling_class__),] for f in ite} - This means we will construct all fields and metadata and hack into - how attrs use some of the 'magic' construction to generate the fields. + # For pickling to work, the __module__ variable needs to be set to the + # frame where the class is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython). + try: cls.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): pass - It also does a few more extra features: It also generate all __openllm_*__ config from - ModelSettings (derived from __config__) to the class. - """ - if not cls.__name__.endswith("Config"): - logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__) - cls.__name__ = f"{cls.__name__}Config" + def __setattr__(self, attr: str, value: t.Any) -> None: + if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.") + super().__setattr__(attr, value) - if not hasattr(cls, "__config__"): raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.") + def __init__(self, *, generation_config: DictStrAny | None = None, __openllm_extras__: DictStrAny | None = None, **attrs: t.Any): + # create a copy of the keys as cache + _cached_keys = tuple(attrs.keys()) - # auto assignment attributes generated from __config__ after create the new slot class. - _make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls) - cls.__openllm_generation_class__ = cls._make_subclass("GenerationConfig", openllm._configuration.GenerationConfig, suffix_env="generation") - cls.__openllm_sampling_class__ = cls._make_subclass("SamplingParams", openllm._configuration.SamplingParams, suffix_env="sampling") + _generation_cl_dict = attr.fields_dict(self.__openllm_generation_class__) + if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict} + else: generation_config = config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict}) - # process a fields under cls.__dict__ and auto convert them with dantic.Field - # this is similar logic to attr._make._transform_attrs - cd = cls.__dict__ - anns = codegen.get_annotations(cls) - # _CountingAttr is the underlying representation of attr.field - ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)} - these: dict[str, _CountingAttr[t.Any]] = {} - annotated_names: set[str] = set() - for attr_name, typ in anns.items(): - if codegen.is_class_var(typ): continue - annotated_names.add(attr_name) - val = cd.get(attr_name, attr.NOTHING) - if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): - if val is attr.NOTHING: val = cls.Field(env=field_env_key(cls.__openllm_model_name__, attr_name)) - else: val = cls.Field(default=val, env=field_env_key(cls.__openllm_model_name__, attr_name)) - these[attr_name] = val - unannotated = ca_names - annotated_names - if len(unannotated) > 0: - missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter) - raise openllm.exceptions.MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}") - # We need to set the accepted key before generation_config - # as generation_config is a special field that users shouldn't pass. - cls.__openllm_accepted_keys__ = set(these.keys()) | {a.name for a in attr.fields(cls.__openllm_generation_class__)} | {a.name for a in attr.fields(cls.__openllm_sampling_class__)} - cls = _ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class() + sampling_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.__openllm_sampling_class__)} - # Finally, resolve the types - if getattr(cls, "__attrs_types_resolved__", None) != cls: - # NOTE: We will try to resolve type here, and cached it for faster use - globs: DictStrAny = {"t": t, "typing": t, "Constraint": Constraint} - if cls.__module__ in sys.modules: globs.update(sys.modules[cls.__module__].__dict__) - attr.resolve_types(cls.__openllm_generation_class__, globalns=globs) - attr.resolve_types(cls.__openllm_sampling_class__, globalns=globs) - cls = attr.resolve_types(cls, globalns=globs) - # the hint cache for easier access - cls.__openllm_hints__ = { - f.name: f.type - for ite in [ - attr.fields(cls), - attr.fields(cls.__openllm_generation_class__), - attr.fields(cls.__openllm_sampling_class__), - ] - for f in ite - } + for k in _cached_keys: + if k in generation_config or k in sampling_config or attrs[k] is None: del attrs[k] - # For pickling to work, the __module__ variable needs to be set to the - # frame where the class is created. Bypass this step in environments where - # sys._getframe is not defined (Jython for example) or sys._getframe is not - # defined for arguments greater than 0 (IronPython). - try: cls.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") - except (AttributeError, ValueError): pass + self.__openllm_extras__ = config_merger.merge(first_not_none(__openllm_extras__, default={}), {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}) + self.generation_config = self["generation_class"](_internal=True, **generation_config) + self.sampling_config = self["sampling_class"].from_generation_config(self.generation_config, **sampling_config) - def __setattr__(self, attr: str, value: t.Any) -> None: - if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.") - super().__setattr__(attr, value) + # The rest of attrs should only be the attributes to be passed to __attrs_init__ + self.__attrs_init__(**attrs) - def __init__(self, *, generation_config: DictStrAny | None = None, __openllm_extras__: DictStrAny | None = None, **attrs: t.Any): - # create a copy of the keys as cache - _cached_keys = tuple(attrs.keys()) + # NOTE: These required fields should be at the top, as it will be kw_only - _generation_cl_dict = attr.fields_dict(self.__openllm_generation_class__) - if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict} - else: generation_config = config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict}) + # fmt: off + # update-config-stubs.py: start + # NOTE: ModelSettings arguments + @overload + def __getitem__(self, item: t.Literal["default_id"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["model_ids"]) -> ListStr: ... + @overload + def __getitem__(self, item: t.Literal["architecture"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["default_implementation"]) -> t.Dict[LiteralResourceSpec, LiteralRuntime]: ... + @overload + def __getitem__(self, item: t.Literal["url"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["requires_gpu"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["service_name"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["requirements"]) -> t.Optional[ListStr]: ... + @overload + def __getitem__(self, item: t.Literal["bettertransformer"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["model_type"]) -> t.Literal["causal_lm", "seq2seq_lm"]: ... + @overload + def __getitem__(self, item: t.Literal["runtime"]) -> t.Literal["transformers", "ggml"]: ... + @overload + def __getitem__(self, item: t.Literal["name_type"]) -> t.Optional[t.Literal["dasherize", "lowercase"]]: ... + @overload + def __getitem__(self, item: t.Literal["model_name"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["start_name"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["env"]) -> openllm.utils.EnvVarMixin: ... + @overload + def __getitem__(self, item: t.Literal["timeout"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["workers_per_resource"]) -> t.Union[int, float]: ... + @overload + def __getitem__(self, item: t.Literal["fine_tune_strategies"]) -> t.Dict[AdapterType, FineTuneConfig]: ... + @overload + def __getitem__(self, item: t.Literal["tokenizer_class"]) -> t.Optional[str]: ... + # NOTE: generation_class, sampling_class and extras arguments + @overload + def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm._configuration.GenerationConfig]: ... + @overload + def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm._configuration.SamplingParams]: ... + @overload + def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ... + # NOTE: GenerationConfig arguments + @overload + def __getitem__(self, item: t.Literal["max_new_tokens"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["min_length"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["min_new_tokens"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["early_stopping"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["max_time"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["num_beams"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["num_beam_groups"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["penalty_alpha"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["use_cache"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["temperature"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["top_k"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["top_p"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["typical_p"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["epsilon_cutoff"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["eta_cutoff"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["diversity_penalty"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["repetition_penalty"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["encoder_repetition_penalty"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["length_penalty"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["no_repeat_ngram_size"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["bad_words_ids"]) -> t.List[t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["force_words_ids"]) -> t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]: ... + @overload + def __getitem__(self, item: t.Literal["renormalize_logits"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["constraints"]) -> t.List[Constraint]: ... + @overload + def __getitem__(self, item: t.Literal["forced_bos_token_id"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["forced_eos_token_id"]) -> t.Union[int, t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["remove_invalid_values"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["exponential_decay_length_penalty"]) -> t.Tuple[int, float]: ... + @overload + def __getitem__(self, item: t.Literal["suppress_tokens"]) -> t.List[int]: ... + @overload + def __getitem__(self, item: t.Literal["begin_suppress_tokens"]) -> t.List[int]: ... + @overload + def __getitem__(self, item: t.Literal["forced_decoder_ids"]) -> t.List[t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["num_return_sequences"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["output_attentions"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["output_hidden_states"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["output_scores"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["pad_token_id"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["bos_token_id"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["eos_token_id"]) -> t.Union[int, t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["encoder_no_repeat_ngram_size"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["decoder_start_token_id"]) -> int: ... + # NOTE: SamplingParams arguments + @overload + def __getitem__(self, item: t.Literal["n"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["best_of"]) -> int: ... + @overload + def __getitem__(self, item: t.Literal["presence_penalty"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["frequency_penalty"]) -> float: ... + @overload + def __getitem__(self, item: t.Literal["use_beam_search"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["stop"]) -> t.List[str]: ... + @overload + def __getitem__(self, item: t.Literal["ignore_eos"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["logprobs"]) -> int: ... + # NOTE: PeftType arguments + @overload + def __getitem__(self, item: t.Literal["prompt_tuning"]) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["p_tuning"]) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["prefix_tuning"]) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["lora"]) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["adalora"]) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["adaption_prompt"]) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["ia3"]) -> dict[str, t.Any]: ... + # update-config-stubs.py: stop - sampling_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.__openllm_sampling_class__)} + # fmt: on - for k in _cached_keys: - if k in generation_config or k in sampling_config or attrs[k] is None: del attrs[k] + def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: + """Allowing access LLMConfig as a dictionary. The order will always evaluate as. - self.__openllm_extras__ = config_merger.merge(first_not_none(__openllm_extras__, default={}), {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}) - self.generation_config = self["generation_class"](_internal=True, **generation_config) - self.sampling_config = self["sampling_class"].from_generation_config(self.generation_config, **sampling_config) + __openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__ - # The rest of attrs should only be the attributes to be passed to __attrs_init__ - self.__attrs_init__(**attrs) + This method is purely for convenience, and should not be used for performance critical code. + """ + if item is None: raise TypeError(f"{self} doesn't understand how to index None.") + item = inflection.underscore(item) + if item in _reserved_namespace: raise ForbiddenAttributeError(f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified.") + internal_attributes = f"__openllm_{item}__" - # NOTE: These required fields should be at the top, as it will be kw_only + if hasattr(self, internal_attributes): return getattr(self, internal_attributes) + elif hasattr(self, item): return getattr(self, item) + elif hasattr(self.__openllm_generation_class__, item): return getattr(self.generation_config, item) + elif hasattr(self.__openllm_sampling_class__, item): return getattr(self.sampling_config, item) + elif item in self.__class__.__openllm_fine_tune_strategies__: return self.__class__.__openllm_fine_tune_strategies__[item] + elif item in self.__openllm_extras__: return self.__openllm_extras__[item] + else: raise KeyError(item) + def __getattribute__(self, item: str) -> t.Any: + if item in _reserved_namespace: raise ForbiddenAttributeError(f"'{item}' belongs to a private namespace for {self.__class__} and should not be access nor modified.") + return _object_getattribute.__get__(self)(item) + def __len__(self) -> int: return len(self.__openllm_accepted_keys__) + len(self.__openllm_extras__) + def keys(self) -> list[str]: return list(self.__openllm_accepted_keys__) + list(self.__openllm_extras__) + def values(self) -> list[t.Any]: return ([getattr(self, k.name) for k in attr.fields(self.__class__)] + [getattr(self.generation_config, k.name) for k in attr.fields(self.__openllm_generation_class__)] + [getattr(self.sampling_config, k.name) for k in attr.fields(self.__openllm_sampling_class__)] + list(self.__openllm_extras__.values())) + def items(self) -> list[tuple[str, t.Any]]: return ([(k.name, getattr(self, k.name)) for k in attr.fields(self.__class__)] + [(k.name, getattr(self.generation_config, k.name)) for k in attr.fields(self.__openllm_generation_class__)] + [(k.name, getattr(self.sampling_config, k.name)) for k in attr.fields(self.__openllm_sampling_class__)] + list(self.__openllm_extras__.items())) + def __iter__(self) -> t.Iterable[str]: return iter(self.keys()) + def __contains__(self, item: t.Any) -> bool: + if item in self.__openllm_extras__: return True + return item in self.__openllm_accepted_keys__ - # update-config-stubs.py: start - # NOTE: ModelSettings arguments - @overload - def __getitem__(self, item: t.Literal["default_id"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["model_ids"]) -> ListStr: ... - @overload - def __getitem__(self, item: t.Literal["architecture"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["default_implementation"]) -> t.Dict[LiteralResourceSpec, LiteralRuntime]: ... - @overload - def __getitem__(self, item: t.Literal["url"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["requires_gpu"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["service_name"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["requirements"]) -> t.Optional[ListStr]: ... - @overload - def __getitem__(self, item: t.Literal["bettertransformer"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["model_type"]) -> t.Literal["causal_lm", "seq2seq_lm"]: ... - @overload - def __getitem__(self, item: t.Literal["runtime"]) -> t.Literal["transformers", "ggml"]: ... - @overload - def __getitem__(self, item: t.Literal["name_type"]) -> t.Optional[t.Literal["dasherize", "lowercase"]]: ... - @overload - def __getitem__(self, item: t.Literal["model_name"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["start_name"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["env"]) -> openllm.utils.EnvVarMixin: ... - @overload - def __getitem__(self, item: t.Literal["timeout"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["workers_per_resource"]) -> t.Union[int, float]: ... - @overload - def __getitem__(self, item: t.Literal["fine_tune_strategies"]) -> t.Dict[AdapterType, FineTuneConfig]: ... - @overload - def __getitem__(self, item: t.Literal["tokenizer_class"]) -> t.Optional[str]: ... - # NOTE: generation_class, sampling_class and extras arguments - @overload - def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm._configuration.GenerationConfig]: ... - @overload - def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm._configuration.SamplingParams]: ... - @overload - def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ... - # NOTE: GenerationConfig arguments - @overload - def __getitem__(self, item: t.Literal["max_new_tokens"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["min_length"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["min_new_tokens"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["early_stopping"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["max_time"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["num_beams"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["num_beam_groups"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["penalty_alpha"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["use_cache"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["temperature"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["top_k"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["top_p"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["typical_p"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["epsilon_cutoff"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["eta_cutoff"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["diversity_penalty"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["repetition_penalty"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["encoder_repetition_penalty"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["length_penalty"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["no_repeat_ngram_size"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["bad_words_ids"]) -> t.List[t.List[int]]: ... - @overload - def __getitem__(self, item: t.Literal["force_words_ids"]) -> t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]: ... - @overload - def __getitem__(self, item: t.Literal["renormalize_logits"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["constraints"]) -> t.List[Constraint]: ... - @overload - def __getitem__(self, item: t.Literal["forced_bos_token_id"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["forced_eos_token_id"]) -> t.Union[int, t.List[int]]: ... - @overload - def __getitem__(self, item: t.Literal["remove_invalid_values"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["exponential_decay_length_penalty"]) -> t.Tuple[int, float]: ... - @overload - def __getitem__(self, item: t.Literal["suppress_tokens"]) -> t.List[int]: ... - @overload - def __getitem__(self, item: t.Literal["begin_suppress_tokens"]) -> t.List[int]: ... - @overload - def __getitem__(self, item: t.Literal["forced_decoder_ids"]) -> t.List[t.List[int]]: ... - @overload - def __getitem__(self, item: t.Literal["num_return_sequences"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["output_attentions"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["output_hidden_states"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["output_scores"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["pad_token_id"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["bos_token_id"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["eos_token_id"]) -> t.Union[int, t.List[int]]: ... - @overload - def __getitem__(self, item: t.Literal["encoder_no_repeat_ngram_size"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["decoder_start_token_id"]) -> int: ... - # NOTE: SamplingParams arguments - @overload - def __getitem__(self, item: t.Literal["n"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["best_of"]) -> int: ... - @overload - def __getitem__(self, item: t.Literal["presence_penalty"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["frequency_penalty"]) -> float: ... - @overload - def __getitem__(self, item: t.Literal["use_beam_search"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["stop"]) -> t.List[str]: ... - @overload - def __getitem__(self, item: t.Literal["ignore_eos"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["logprobs"]) -> int: ... - # NOTE: PeftType arguments - @overload - def __getitem__(self, item: t.Literal["prompt_tuning"]) -> dict[str, t.Any]: ... - @overload - def __getitem__(self, item: t.Literal["p_tuning"]) -> dict[str, t.Any]: ... - @overload - def __getitem__(self, item: t.Literal["prefix_tuning"]) -> dict[str, t.Any]: ... - @overload - def __getitem__(self, item: t.Literal["lora"]) -> dict[str, t.Any]: ... - @overload - def __getitem__(self, item: t.Literal["adalora"]) -> dict[str, t.Any]: ... - @overload - def __getitem__(self, item: t.Literal["adaption_prompt"]) -> dict[str, t.Any]: ... - @overload - def __getitem__(self, item: t.Literal["ia3"]) -> dict[str, t.Any]: ... - # update-config-stubs.py: stop + @classmethod + def model_derivate(cls, name: str | None = None, **attrs: t.Any) -> LLMConfig: + """A helper class to generate a new LLMConfig class with additional attributes. - def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: - """Allowing access LLMConfig as a dictionary. The order will always evaluate as. + This is useful to modify builtin __config__ value attributes. - __openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__ + ```python + class DollyV2Config(openllm.LLMConfig): + ... - This method is purely for convenience, and should not be used for performance critical code. - """ - if item is None: - raise TypeError(f"{self} doesn't understand how to index None.") - item = inflection.underscore(item) - if item in _reserved_namespace: - raise ForbiddenAttributeError( - f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified." - ) - internal_attributes = f"__openllm_{item}__" - if hasattr(self, internal_attributes): - return getattr(self, internal_attributes) - elif hasattr(self, item): - return getattr(self, item) - elif hasattr(self.__openllm_generation_class__, item): - return getattr(self.generation_config, item) - elif hasattr(self.__openllm_sampling_class__, item): - return getattr(self.sampling_config, item) - elif item in self.__class__.__openllm_fine_tune_strategies__: - return self.__class__.__openllm_fine_tune_strategies__[item] - elif item in self.__openllm_extras__: - return self.__openllm_extras__[item] - else: - raise KeyError(item) + my_new_class = DollyV2Config.model_derivate(default_id='...') + ``` - def __getattribute__(self, item: str) -> t.Any: - if item in _reserved_namespace: raise ForbiddenAttributeError(f"'{item}' belongs to a private namespace for {self.__class__} and should not be access nor modified.") - return _object_getattribute.__get__(self)(item) + Args: + name: The name of the new class. + **attrs: The attributes to be added to the new class. This will override + any existing attributes with the same name. + """ + if not hasattr(cls, "__config__"): + raise ValueError("Cannot derivate a LLMConfig without __config__") + _new_cfg = {k: v for k, v in attrs.items() if k in attr.fields_dict(_ModelSettingsAttr)} + attrs = {k: v for k, v in attrs.items() if k not in _new_cfg} + new_cls = types.new_class(name or f"{cls.__name__.replace('Config', '')}DerivateConfig", (cls,), {}, lambda ns: ns.update({ + "__config__": config_merger.merge(copy.deepcopy(cls.__dict__["__config__"]), _new_cfg), "__base_config__": cls, # keep a reference for easy access + })) - def __len__(self) -> int: return len(self.__openllm_accepted_keys__) + len(self.__openllm_extras__) - def keys(self) -> list[str]: return list(self.__openllm_accepted_keys__) + list(self.__openllm_extras__) - def values(self) -> list[t.Any]: - return ( - [getattr(self, k.name) for k in attr.fields(self.__class__)] - + [getattr(self.generation_config, k.name) for k in attr.fields(self.__openllm_generation_class__)] - + [getattr(self.sampling_config, k.name) for k in attr.fields(self.__openllm_sampling_class__)] - + list(self.__openllm_extras__.values()) - ) - def items(self) -> list[tuple[str, t.Any]]: - return ( - [(k.name, getattr(self, k.name)) for k in attr.fields(self.__class__)] - + [ - (k.name, getattr(self.generation_config, k.name)) - for k in attr.fields(self.__openllm_generation_class__) - ] - + [(k.name, getattr(self.sampling_config, k.name)) for k in attr.fields(self.__openllm_sampling_class__)] - + list(self.__openllm_extras__.items()) - ) - def __iter__(self) -> t.Iterable[str]: return iter(self.keys()) - def __contains__(self, item: t.Any) -> bool: - if item in self.__openllm_extras__: return True - return item in self.__openllm_accepted_keys__ - @classmethod - def model_derivate(cls, name: str | None = None, **attrs: t.Any) -> LLMConfig: - """A helper class to generate a new LLMConfig class with additional attributes. + # For pickling to work, the __module__ variable needs to be set to the + # frame where the class is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython). + try: new_cls.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): pass + return new_cls(**attrs) - This is useful to modify builtin __config__ value attributes. + def model_dump(self, flatten: bool = False, **_: t.Any) -> DictStrAny: + dumped = bentoml_cattr.unstructure(self) + generation_config = bentoml_cattr.unstructure(self.generation_config) + sampling_config = bentoml_cattr.unstructure(self.sampling_config) + if flatten: dumped.update(generation_config) + else: dumped["generation_config"] = generation_config + dumped.update(sampling_config) + return dumped + def model_dump_json(self, **kwargs: t.Any) -> bytes: return orjson.dumps(self.model_dump(**kwargs)) + @classmethod + def model_construct_json(cls, json_str: str | bytes) -> t.Self: + try: attrs = orjson.loads(json_str) + except orjson.JSONDecodeError as err: raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None + return bentoml_cattr.structure(attrs, cls) + @classmethod + def model_construct_env(cls, **attrs: t.Any) -> t.Self: + """A helpers that respect configuration values environment variables.""" + attrs = {k: v for k, v in attrs.items() if v is not None} + model_config = cls.__openllm_env__.config + env_json_string = os.environ.get(model_config, None) - ```python - class DollyV2Config(openllm.LLMConfig): - ... + config_from_env: DictStrAny = {} + if env_json_string is not None: + try: config_from_env = orjson.loads(env_json_string) + except orjson.JSONDecodeError as e: raise RuntimeError(f"Failed to parse '{model_config}' as valid JSON string.") from e - my_new_class = DollyV2Config.model_derivate(default_id='...') - ``` + if "generation_config" in attrs: + generation_config = attrs.pop("generation_config") + if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") + else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)} - Args: - name: The name of the new class. - **attrs: The attributes to be added to the new class. This will override - any existing attributes with the same name. - """ - if not hasattr(cls, "__config__"): - raise ValueError("Cannot derivate a LLMConfig without __config__") - _new_cfg = {k: v for k, v in attrs.items() if k in attr.fields_dict(_ModelSettingsAttr)} - attrs = {k: v for k, v in attrs.items() if k not in _new_cfg} - new_cls = types.new_class( - name or f"{cls.__name__.replace('Config', '')}DerivateConfig", - (cls,), - {}, - lambda ns: ns.update( - { - "__config__": config_merger.merge(copy.deepcopy(cls.__dict__["__config__"]), _new_cfg), - "__base_config__": cls, # keep a reference for easy access - } - ), - ) + for k in tuple(attrs.keys()): + if k in generation_config: del attrs[k] - # For pickling to work, the __module__ variable needs to be set to the - # frame where the class is created. Bypass this step in environments where - # sys._getframe is not defined (Jython for example) or sys._getframe is not - # defined for arguments greater than 0 (IronPython). - try: new_cls.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") - except (AttributeError, ValueError): pass - return new_cls(**attrs) - def model_dump(self, flatten: bool = False, **_: t.Any) -> DictStrAny: - dumped = bentoml_cattr.unstructure(self) - generation_config = bentoml_cattr.unstructure(self.generation_config) - sampling_config = bentoml_cattr.unstructure(self.sampling_config) - if flatten: dumped.update(generation_config) - else: dumped["generation_config"] = generation_config - dumped.update(sampling_config) - return dumped - def model_dump_json(self, **kwargs: t.Any) -> bytes: return orjson.dumps(self.model_dump(**kwargs)) - @classmethod - def model_construct_json(cls, json_str: str | bytes) -> t.Self: - try: - attrs = orjson.loads(json_str) - except orjson.JSONDecodeError as err: - raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None - return bentoml_cattr.structure(attrs, cls) - @classmethod - def model_construct_env(cls, **attrs: t.Any) -> t.Self: - """A helpers that respect configuration values environment variables.""" - attrs = {k: v for k, v in attrs.items() if v is not None} + config_from_env.update(attrs) + config_from_env["generation_config"] = generation_config + return bentoml_cattr.structure(config_from_env, cls) + def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]: + """Parse given click attributes into a LLMConfig and return the remaining click attributes.""" + llm_config_attrs: DictStrAny = {"generation_config": {}} + key_to_remove: ListStr = [] + for k, v in attrs.items(): + if k.startswith(f"{self['model_name']}_generation_"): + llm_config_attrs["generation_config"][k[len(self["model_name"] + "_generation_"):]] = v + key_to_remove.append(k) + elif k.startswith(f"{self['model_name']}_sampling_"): + llm_config_attrs[k[len(self["model_name"] + "_sampling_"):]] = v + key_to_remove.append(k) + elif k.startswith(f"{self['model_name']}_"): + llm_config_attrs[k[len(self["model_name"] + "_"):]] = v + key_to_remove.append(k) + return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove} - model_config = cls.__openllm_env__.config + @overload + def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig: ... + @overload + def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny: ... + def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: + config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config)) + return config.to_dict() if return_as_dict else config + @requires_dependencies("vllm", extra="vllm") + def to_sampling_config(self) -> vllm.SamplingParams: return self.sampling_config.to_vllm() + @classmethod + def to_click_options(cls, f: AnyCallable) -> click.Command: + """Convert current configuration to click options. - env_json_string = os.environ.get(model_config, None) + This can be used as a decorator for click commands. - config_from_env: DictStrAny = {} - if env_json_string is not None: - try: config_from_env = orjson.loads(env_json_string) - except orjson.JSONDecodeError as e: raise RuntimeError(f"Failed to parse '{model_config}' as valid JSON string.") from e + > **Note**: that the identifier for all LLMConfig will be prefixed with '_*', and the generation config + will be prefixed with '_generation_*'. + """ + for name, field in attr.fields_dict(cls.__openllm_generation_class__).items(): + ty = cls.__openllm_hints__.get(name) + # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. + if t.get_origin(ty) is t.Union: continue + f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f) + f = cog.optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f) - if "generation_config" in attrs: - generation_config = attrs.pop("generation_config") - if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") - else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)} + for name, field in attr.fields_dict(cls.__openllm_sampling_class__).items(): + ty = cls.__openllm_hints__.get(name) + # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. + if t.get_origin(ty) is t.Union: continue + f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_sampling=True)(f) + f = cog.optgroup.group(f"{cls.__openllm_sampling_class__.__name__} sampling options")(f) - for k in tuple(attrs.keys()): - if k in generation_config: del attrs[k] + total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__)) - config_from_env.update(attrs) - config_from_env["generation_config"] = generation_config - return bentoml_cattr.structure(config_from_env, cls) - def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]: - """Parse given click attributes into a LLMConfig and return the remaining click attributes.""" - llm_config_attrs: DictStrAny = {"generation_config": {}} - key_to_remove: ListStr = [] + if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return f - for k, v in attrs.items(): - if k.startswith(f"{self['model_name']}_generation_"): - llm_config_attrs["generation_config"][k[len(self["model_name"] + "_generation_") :]] = v - key_to_remove.append(k) - elif k.startswith(f"{self['model_name']}_sampling_"): - llm_config_attrs[k[len(self["model_name"] + "_sampling_") :]] = v - key_to_remove.append(k) - elif k.startswith(f"{self['model_name']}_"): - llm_config_attrs[k[len(self["model_name"] + "_") :]] = v - key_to_remove.append(k) + # We pop out 'generation_config' as it is a attribute that we don't need to expose to CLI. + for name, field in attr.fields_dict(cls).items(): + ty = cls.__openllm_hints__.get(name) + # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. + if t.get_origin(ty) is t.Union or name == "generation_config" or name == "sampling_config": continue + f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f) - return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove} - @overload - def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig: ... - @overload - def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny: ... - def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: - config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config)) - return config.to_dict() if return_as_dict else config - @requires_dependencies("vllm", extra="vllm") - def to_sampling_config(self) -> vllm.SamplingParams: return self.sampling_config.to_vllm() - @classmethod - def to_click_options(cls, f: AnyCallable) -> click.Command: - """Convert current configuration to click options. + return cog.optgroup.group(f"{cls.__name__} options")(f) - This can be used as a decorator for click commands. - - > **Note**: that the identifier for all LLMConfig will be prefixed with '_*', and the generation config - will be prefixed with '_generation_*'. - """ - for name, field in attr.fields_dict(cls.__openllm_generation_class__).items(): - ty = cls.__openllm_hints__.get(name) - # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. - if t.get_origin(ty) is t.Union: continue - f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f) - f = cog.optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f) - - for name, field in attr.fields_dict(cls.__openllm_sampling_class__).items(): - ty = cls.__openllm_hints__.get(name) - # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. - if t.get_origin(ty) is t.Union: continue - f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_sampling=True)(f) - f = cog.optgroup.group(f"{cls.__openllm_sampling_class__.__name__} sampling options")(f) - - total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__)) - - if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return f - - # We pop out 'generation_config' as it is a attribute that we don't need to expose to CLI. - for name, field in attr.fields_dict(cls).items(): - ty = cls.__openllm_hints__.get(name) - # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. - if t.get_origin(ty) is t.Union or name == "generation_config" or name == "sampling_config": continue - f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f) - - return cog.optgroup.group(f"{cls.__name__} options")(f) - # holds a mapping from self.__openllm_model_type__ to peft.TaskType - @classmethod - def peft_task_type(cls) -> str: return _PEFT_TASK_TYPE_TARGET_MAPPING[cls.__openllm_model_type__] - @classmethod - def default_implementation(cls) -> LiteralRuntime: return first_not_none(cls.__openllm_env__["framework_value"], default=get_default_implementation(cls.__openllm_default_implementation__)) - - -bentoml_cattr.register_unstructure_hook_factory( - lambda cls: lenient_issubclass(cls, LLMConfig), - lambda cls: make_dict_unstructure_fn( - cls, - bentoml_cattr, - _cattrs_omit_if_default=False, - _cattrs_use_linecache=True, - ), -) + # holds a mapping from self.__openllm_model_type__ to peft.TaskType + @classmethod + def peft_task_type(cls) -> str: return _PEFT_TASK_TYPE_TARGET_MAPPING[cls.__openllm_model_type__] + @classmethod + def default_implementation(cls) -> LiteralRuntime: return first_not_none(cls.__openllm_env__["framework_value"], default=get_default_implementation(cls.__openllm_default_implementation__)) +bentoml_cattr.register_unstructure_hook_factory(lambda cls: lenient_issubclass(cls, LLMConfig), lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False, _cattrs_use_linecache=True)) def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig: - """Structure a dictionary to a LLMConfig object. + """Structure a dictionary to a LLMConfig object. - Essentially, if the given dictionary contains a 'generation_config' key, then we will - use it for LLMConfig.generation_config + Essentially, if the given dictionary contains a 'generation_config' key, then we will + use it for LLMConfig.generation_config - Otherwise, we will filter out all keys are first in LLMConfig, parse it in, then - parse the remaining keys into LLMConfig.generation_config - """ - if not LazyType(DictStrAny).isinstance(data): - raise RuntimeError(f"Expected a dictionary, but got {type(data)}") - - cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__} - generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__) - if "generation_config" in data: - generation_config = data.pop("generation_config") - if not LazyType(DictStrAny).isinstance(generation_config): - raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") - config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields}) - else: - generation_config = {k: v for k, v in data.items() if k in generation_cls_fields} - # The rest should be passed to extras - data = {k: v for k, v in data.items() if k not in cls.__openllm_accepted_keys__} - - return cls(generation_config=generation_config, __openllm_extras__=data, **cls_attrs) + Otherwise, we will filter out all keys are first in LLMConfig, parse it in, then + parse the remaining keys into LLMConfig.generation_config + """ + if not LazyType(DictStrAny).isinstance(data): raise RuntimeError(f"Expected a dictionary, but got {type(data)}") + cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__} + generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__) + if "generation_config" in data: + generation_config = data.pop("generation_config") + if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") + config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields}) + else: generation_config = {k: v for k, v in data.items() if k in generation_cls_fields} + # The rest should be passed to extras + data = {k: v for k, v in data.items() if k not in cls.__openllm_accepted_keys__} + return cls(generation_config=generation_config, __openllm_extras__=data, **cls_attrs) bentoml_cattr.register_structure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), structure_llm_config) -openllm_home = os.path.expanduser( - os.getenv( - "OPENLLM_HOME", - os.path.join(os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")), "openllm"), - ) -) +openllm_home = os.path.expanduser(os.getenv("OPENLLM_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")), "openllm"))) diff --git a/src/openllm/_generation.py b/src/openllm/_generation.py index c0ee0e28..dbe58952 100644 --- a/src/openllm/_generation.py +++ b/src/openllm/_generation.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Generation utilities to be reused throughout.""" from __future__ import annotations import typing as t @@ -19,29 +18,12 @@ import typing as t import transformers if t.TYPE_CHECKING: - import torch - + import torch class StopSequenceCriteria(transformers.StoppingCriteria): - """This class used to stop generation when a seq of tokens are met. - - Args: - stop_sequences: `str` or `list[str]` of the sequence (list of sequences) on which to stop execution. - tokenizer: Tokenizer to be used to decode the model outputs. - """ - - def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer): - if isinstance(stop_sequences, str): - stop_sequences = [stop_sequences] - self.stop_sequences = stop_sequences - self.tokenizer: transformers.PreTrainedTokenizer = tokenizer - - def __call__(self, input_ids: torch.Tensor, scores: t.Any, **attrs: t.Any) -> bool: - decoded_output = self.tokenizer.decode(input_ids.tolist()[0]) - return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences) - - + def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer): + if isinstance(stop_sequences, str): stop_sequences = [stop_sequences] + self.stop_sequences,self.tokenizer = stop_sequences, tokenizer + def __call__(self, input_ids: torch.Tensor, scores: t.Any, **attrs: t.Any) -> bool: return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences) class StopOnTokens(transformers.StoppingCriteria): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool: - stop_ids = {50278, 50279, 50277, 1, 0} - return t.cast(int, input_ids[0][-1]) in stop_ids + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool: return t.cast(int, input_ids[0][-1]) in {50278, 50279, 50277, 1, 0} diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index c12e3153..5abe4041 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -74,92 +74,95 @@ from .utils import validate_is_path # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import NotRequired - from typing import overload + from typing import NotRequired + from typing import overload else: - from typing_extensions import NotRequired - from typing_extensions import overload + from typing_extensions import NotRequired + from typing_extensions import overload if t.TYPE_CHECKING: - import auto_gptq as autogptq - import peft - import torch - import vllm + import auto_gptq as autogptq + import peft + import torch + import vllm - import transformers + import transformers - from ._configuration import PeftType - from ._types import AdaptersMapping - from ._types import AdaptersTuple - from ._types import AnyCallable - from ._types import DictStrAny - from ._types import ListStr - from ._types import LiteralRuntime - from ._types import LLMEmbeddings - from ._types import LLMRunnable - from ._types import LLMRunner - from ._types import ModelSignatureDict as _ModelSignatureDict - from ._types import PeftAdapterOutput - from ._types import TupleAny - from .utils.representation import ReprArgs + from ._configuration import PeftType + from ._types import AdaptersMapping + from ._types import AdaptersTuple + from ._types import AnyCallable + from ._types import DictStrAny + from ._types import ListStr + from ._types import LiteralRuntime + from ._types import LLMEmbeddings + from ._types import LLMRunnable + from ._types import LLMRunner + from ._types import ModelSignatureDict as _ModelSignatureDict + from ._types import PeftAdapterOutput + from ._types import TupleAny + from .utils.representation import ReprArgs - UserDictAny = collections.UserDict[str, t.Any] - ResolvedAdaptersMapping = dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] + UserDictAny = collections.UserDict[str, t.Any] + ResolvedAdaptersMapping = dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] else: - DictStrAny = dict - TupleAny = tuple - UserDictAny = collections.UserDict - LLMRunnable = bentoml.Runnable - LLMRunner = bentoml.Runner - LLMEmbeddings = dict + DictStrAny = dict + TupleAny = tuple + UserDictAny = collections.UserDict + LLMRunnable = bentoml.Runnable + LLMRunner = bentoml.Runner + LLMEmbeddings = dict - autogptq = LazyLoader("autogptq", globals(), "auto_gptq") - vllm = LazyLoader("vllm", globals(), "vllm") - transformers = LazyLoader("transformers", globals(), "transformers") - torch = LazyLoader("torch", globals(), "torch") - peft = LazyLoader("peft", globals(), "peft") + autogptq = LazyLoader("autogptq", globals(), "auto_gptq") + vllm = LazyLoader("vllm", globals(), "vllm") + transformers = LazyLoader("transformers", globals(), "transformers") + torch = LazyLoader("torch", globals(), "torch") + peft = LazyLoader("peft", globals(), "peft") logger = logging.getLogger(__name__) class ModelSignatureDict(t.TypedDict, total=False): - batchable: bool - batch_dim: t.Union[t.Tuple[int, int], int] - input_spec: NotRequired[t.Union[t.Any, t.Tuple[t.Any]]] - output_spec: NotRequired[t.Any] + batchable: bool + batch_dim: t.Union[t.Tuple[int, int], int] + input_spec: NotRequired[t.Union[t.Any, t.Tuple[t.Any]]] + output_spec: NotRequired[t.Any] -def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else re.sub("[^a-zA-Z0-9]+", "-", name) +def normalise_model_name(name: str) -> str: + return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else re.sub("[^a-zA-Z0-9]+", "-", name) # the below is similar to peft.utils.other.CONFIG_NAME PEFT_CONFIG_NAME = "adapter_config.json" def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapping: - """Resolve the type of the PeftConfig given the adapter_map. + """Resolve the type of the PeftConfig given the adapter_map. - This is similar to how PeftConfig resolve its config type. + This is similar to how PeftConfig resolve its config type. - Args: - adapter_map: The given mapping from either SDK or CLI. See CLI docs for more information. - """ - resolved: AdaptersMapping = {} - _has_set_default = False - for path_or_adapter_id, name in adapter_map.items(): - resolve_name = name - if resolve_name is None: - if _has_set_default: raise ValueError("Only one adapter can be set as default.") - resolve_name = "default" - _has_set_default = True - if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)): - config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME) - else: - try: config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME) - except Exception as err: raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err - with open(config_file, "r") as file: - resolved_config = orjson.loads(file.read()) - # all peft_type should be available in PEFT_CONFIG_NAME - _peft_type: AdapterType = resolved_config["peft_type"].lower() - if _peft_type not in resolved: resolved[_peft_type] = () - resolved[_peft_type] += (_AdaptersTuple((path_or_adapter_id, resolve_name, resolved_config)),) - return resolved + Args: + adapter_map: The given mapping from either SDK or CLI. See CLI docs for more information. + """ + resolved: AdaptersMapping = {} + _has_set_default = False + for path_or_adapter_id, name in adapter_map.items(): + resolve_name = name + if resolve_name is None: + if _has_set_default: raise ValueError("Only one adapter can be set as default.") + resolve_name = "default" + _has_set_default = True + if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)): + config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME) + else: + try: + config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME) + except Exception as err: + raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err + with open(config_file, "r") as file: + resolved_config = orjson.loads(file.read()) + # all peft_type should be available in PEFT_CONFIG_NAME + _peft_type: AdapterType = resolved_config["peft_type"].lower() + if _peft_type not in resolved: resolved[_peft_type] = () + resolved[_peft_type] += (_AdaptersTuple((path_or_adapter_id, resolve_name, resolved_config)),) + return resolved _reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"} @@ -167,105 +170,103 @@ M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pip T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]") def _default_post_init(self: LLM[t.Any, t.Any]) -> None: - if self.__llm_implementation__ == "pt" and is_torch_available(): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.__llm_implementation__ == "pt" and is_torch_available(): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class LLMInterface(ABC, t.Generic[M, T]): - """This defines the loose contract for all openllm.LLM implementations.""" - @property - def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None: - """The default import kwargs to used when importing the model. + """This defines the loose contract for all openllm.LLM implementations.""" + @property + def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None: + """The default import kwargs to used when importing the model. - This will be passed into 'openllm.LLM.import_model'. - It returns two dictionaries: one for model kwargs and one for tokenizer kwargs. + This will be passed into 'openllm.LLM.import_model'. + It returns two dictionaries: one for model kwargs and one for tokenizer kwargs. - Returns: - Optional tuple of model kwargs and tokenizer kwargs - """ - def embeddings(self, prompts: list[str]) -> LLMEmbeddings: - """The implementation for generating text embeddings from given prompt. + Returns: + Optional tuple of model kwargs and tokenizer kwargs + """ + def embeddings(self, prompts: list[str]) -> LLMEmbeddings: + """The implementation for generating text embeddings from given prompt. - It takes the prompt and output the embeddings for this given LLM. + It takes the prompt and output the embeddings for this given LLM. - Returns: - The embeddings for the given prompt. - """ - raise NotImplementedError - @abstractmethod - def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any: - """The implementation for text generation from given prompt. + Returns: + The embeddings for the given prompt. + """ + raise NotImplementedError + @abstractmethod + def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any: + """The implementation for text generation from given prompt. - It takes the prompt and 'generation_kwargs' from 'self.sanitize_parameters' and then - pass it to 'self.model.generate'. - """ - raise NotImplementedError - def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> t.Sequence[dict[t.Literal["generated_text"], str]]: - """The entrypoint for generating one prompt. + It takes the prompt and 'generation_kwargs' from 'self.sanitize_parameters' and then pass it to 'self.model.generate'. + """ + raise NotImplementedError + def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> t.Sequence[dict[t.Literal["generated_text"], str]]: + """The entrypoint for generating one prompt. - This provides additional stop tokens for generating per token level. - This is useful when running with agents, or initial streaming support. - """ - raise NotImplementedError - def generate_iterator(self, prompt: str, **attrs: t.Any) -> t.Iterator[t.Any]: - """The iterator version of `generate` function.""" - raise NotImplementedError("Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented.") - def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]: - """This handler will sanitize all attrs and setup prompt text. + This provides additional stop tokens for generating per token level. This is useful when running with agents, or initial streaming support. + """ + raise NotImplementedError + def generate_iterator(self, prompt: str, **attrs: t.Any) -> t.Iterator[t.Any]: + """The iterator version of `generate` function.""" + raise NotImplementedError("Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented.") + def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]: + """This handler will sanitize all attrs and setup prompt text. - It takes a prompt that is given by the user, attrs that can be parsed with the prompt. + It takes a prompt that is given by the user, attrs that can be parsed with the prompt. - Returns a tuple of three items: - - The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig - - The attributes dictionary that will be passed into `self.postprocess_generate`. - """ - return prompt, attrs, attrs - def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any: - """This handler will postprocess generation results from LLM.generate and then output nicely formatted results (if the LLM decide to do so.). + Returns a tuple of three items: + - The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig + - The attributes dictionary that will be passed into `self.postprocess_generate`. + """ + return prompt, attrs, attrs + def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any: + """This handler will postprocess generation results from LLM.generate and then output nicely formatted results (if the LLM decide to do so.). - You can customize how the output of the LLM looks with this hook. By default, it is a simple echo. + You can customize how the output of the LLM looks with this hook. By default, it is a simple echo. - NOTE: this will be used from the client side. - """ - return generation_result - def llm_post_init(self) -> None: - """This function can be implemented if you need to initialized any additional variables that doesn't concern OpenLLM internals.""" - pass - def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: - """This function can be implemented if default import_model doesn't satisfy your needs. + NOTE: this will be used from the client side. + """ + return generation_result + def llm_post_init(self) -> None: + """This function can be implemented if you need to initialized any additional variables that doesn't concern OpenLLM internals.""" + pass + def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: + """This function can be implemented if default import_model doesn't satisfy your needs. - Note that tokenizer attrs can be accessed via ``llm.llm_parameters``. + Note that tokenizer attrs can be accessed via ``llm.llm_parameters``. - ```python - _, tokenizer_attrs = llm.llm_parameters - ``` + ```python + _, tokenizer_attrs = llm.llm_parameters + ``` - By default, `model_decls` and `model_attrs` is already sanitised and concatenated into `args` and `attrs` - """ - raise NotImplementedError - def load_model(self, *args: t.Any, **attrs: t.Any) -> M: - """This function can be implemented to override the default load_model behaviour. + By default, `model_decls` and `model_attrs` is already sanitised and concatenated into `args` and `attrs` + """ + raise NotImplementedError + def load_model(self, *args: t.Any, **attrs: t.Any) -> M: + """This function can be implemented to override the default load_model behaviour. - See falcon for example implementation. Tag can be accessed via ``self.tag`` - """ - raise NotImplementedError - def load_tokenizer(self, tag: bentoml.Tag, **attrs: t.Any) -> T: - """This function can be implemented to override how to load the tokenizer. + See falcon for example implementation. Tag can be accessed via ``self.tag`` + """ + raise NotImplementedError + def load_tokenizer(self, tag: bentoml.Tag, **attrs: t.Any) -> T: + """This function can be implemented to override how to load the tokenizer. - See falcon for example implementation. - """ - raise NotImplementedError - def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None: - """This function defines how this model can be saved to local store. + See falcon for example implementation. + """ + raise NotImplementedError + def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None: + """This function defines how this model can be saved to local store. - This will be called during ``import_model``. By default, it will use ``openllm.serialisation.save_pretrained``. - Additionally, the function signature are similar to ``transformers.PreTrainedModel.save_pretrained`` - This is useful during fine tuning. - """ - raise NotImplementedError - # NOTE: All fields below are attributes that can be accessed by users. - config_class: type[LLMConfig] - """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class.""" - bettertransformer: bool - """Whether to load this LLM with FasterTransformer enabled. The order of loading is: + This will be called during ``import_model``. By default, it will use ``openllm.serialisation.save_pretrained``. + Additionally, the function signature are similar to ``transformers.PreTrainedModel.save_pretrained`` + This is useful during fine tuning. + """ + raise NotImplementedError + # NOTE: All fields below are attributes that can be accessed by users. + config_class: type[LLMConfig] + """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class.""" + bettertransformer: bool + """Whether to load this LLM with FasterTransformer enabled. The order of loading is: - If pass within `for_model`, `from_pretrained` or `__init__`. - If `self.bettertransformer` is set within `llm_post_init`. @@ -273,19 +274,19 @@ class LLMInterface(ABC, t.Generic[M, T]): > **Note** that if LoRA is enabled, bettertransformer will be disabled. """ - device: "torch.device" - """The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string.""" - tokenizer_id: t.LiteralString | t.Literal["local"] - """optional tokenizer_id for loading with vLLM if the model supports vLLM.""" - # NOTE: The following will be populated by __init_subclass__, note that these should be immutable. - __llm_trust_remote_code__: bool - """This is used to determine during 'import_model' whether to trust remote code or not. + device: "torch.device" + """The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string.""" + tokenizer_id: t.LiteralString | t.Literal["local"] + """optional tokenizer_id for loading with vLLM if the model supports vLLM.""" + # NOTE: The following will be populated by __init_subclass__, note that these should be immutable. + __llm_trust_remote_code__: bool + """This is used to determine during 'import_model' whether to trust remote code or not. This works synonymous with `trust_remote_code` kwarg in transformers Auto classes. If not passed, then by default fallback to config_class['trust_remote_code'] """ - __llm_implementation__: LiteralRuntime - """This is used to determine which implementation that this LLM has. + __llm_implementation__: LiteralRuntime + """This is used to determine which implementation that this LLM has. Usually, this will inferred from class name, that follows the HuggingFace's naming convention: @@ -295,910 +296,850 @@ class LLMInterface(ABC, t.Generic[M, T]): An additional naming for all VLLM backend: VLLMLlama -> `vllm` """ - __llm_model__: M | None - """A reference to the actual model. Instead of access this directly, you should use `model` property instead.""" - __llm_tokenizer__: T | None - """A reference to the actual tokenizer. Instead of access this directly, you should use `tokenizer` property instead.""" - __llm_bentomodel__: bentoml.Model | None - """A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead.""" - __llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None - """A reference to the the cached LoRA adapter mapping.""" - __llm_supports_embeddings__: bool - """A boolean to determine whether models does implement ``LLM.embeddings``.""" - __llm_supports_generate__: bool - """A boolean to determine whether models does implement ``LLM.generate``.""" - __llm_supports_generate_one__: bool - """A boolean to determine whether models does implement ``LLM.generate_one``.""" - __llm_supports_generate_iterator__: bool - """A boolean to determine whether models does implement ``LLM.generate_iterator``.""" - if t.TYPE_CHECKING and not MYPY: - def __attrs_init__( - self, - config: LLMConfig, - quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, - model_id: str, - runtime: t.Literal["ggml", "transformers"], - model_decls: TupleAny, - model_attrs: DictStrAny, - tokenizer_attrs: DictStrAny, - tag: bentoml.Tag, - adapters_mapping: AdaptersMapping | None, - model_version: str | None, - quantize_method: t.Literal["int8", "int4", "gptq"] | None, - serialisation_format: t.Literal["safetensors", "legacy"], - **attrs: t.Any, - ) -> None: - """Generated __attrs_init__ for openllm.LLM.""" + __llm_model__: M | None + """A reference to the actual model. Instead of access this directly, you should use `model` property instead.""" + __llm_tokenizer__: T | None + """A reference to the actual tokenizer. Instead of access this directly, you should use `tokenizer` property instead.""" + __llm_bentomodel__: bentoml.Model | None + """A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead.""" + __llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None + """A reference to the the cached LoRA adapter mapping.""" + __llm_supports_embeddings__: bool + """A boolean to determine whether models does implement ``LLM.embeddings``.""" + __llm_supports_generate__: bool + """A boolean to determine whether models does implement ``LLM.generate``.""" + __llm_supports_generate_one__: bool + """A boolean to determine whether models does implement ``LLM.generate_one``.""" + __llm_supports_generate_iterator__: bool + """A boolean to determine whether models does implement ``LLM.generate_iterator``.""" + if t.TYPE_CHECKING and not MYPY: + def __attrs_init__( + self, config: LLMConfig, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: AdaptersMapping | None, model_version: str | None, + quantize_method: t.Literal["int8", "int4", "gptq"] | None, serialisation_format: t.Literal["safetensors", "legacy"], **attrs: t.Any, + ) -> None: + """Generated __attrs_init__ for openllm.LLM.""" if t.TYPE_CHECKING: - _R = t.TypeVar("_R") - class _import_model_wrapper(t.Generic[_R, M, T]): - def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: ... - class _load_model_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: ... - class _load_tokenizer_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T: ... - class _llm_post_init_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T]) -> T: ... - class _save_pretrained_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: ... + _R = t.TypeVar("_R") + + class _import_model_wrapper(t.Generic[_R, M, T]): + def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: + ... + class _load_model_wrapper(t.Generic[M, T]): + def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: + ... + class _load_tokenizer_wrapper(t.Generic[M, T]): + def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T: + ... + class _llm_post_init_wrapper(t.Generic[M, T]): + def __call__(self, llm: LLM[M, T]) -> T: + ... + class _save_pretrained_wrapper(t.Generic[M, T]): + def __call__(self, llm: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: + ... _object_setattr = object.__setattr__ +# NOTE: the following wrapper are a light meta ops for wrapping default params to internal methods implementation. def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]: - @functools.wraps(f) - def wrapper(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model: - trust_remote_code = first_not_none(trust_remote_code, default=self.__llm_trust_remote_code__) - # wrapped around custom init to provide some meta compression - # for all decls and attrs - (model_decls, model_attrs), _ = self.llm_parameters - decls = (*model_decls, *decls) - attrs = {**model_attrs, **attrs} - return f(self, *decls, trust_remote_code=trust_remote_code, **attrs) - return wrapper + @functools.wraps(f) + def wrapper(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model: + trust_remote_code: bool = first_not_none(trust_remote_code, default=self.__llm_trust_remote_code__) + (model_decls, model_attrs), _ = self.llm_parameters + decls = (*model_decls, *decls) + attrs = {**model_attrs, **attrs} + return f(self, *decls, trust_remote_code=trust_remote_code, **attrs) + + return wrapper _DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer" @requires_dependencies("vllm", extra="vllm") def get_engine_args(llm: LLM[M, T], tokenizer: str = _DEFAULT_TOKENIZER) -> vllm.EngineArgs: - return vllm.EngineArgs(model=llm._bentomodel.path, tokenizer=tokenizer, tokenizer_mode="auto", tensor_parallel_size=1 if device_count() < 2 else device_count(), dtype="auto", worker_use_ray=False) + return vllm.EngineArgs(model=llm._bentomodel.path, tokenizer=tokenizer, tokenizer_mode="auto", tensor_parallel_size=1 if device_count() < 2 else device_count(), dtype="auto", worker_use_ray=False) + def _wrapped_load_model(f: _load_model_wrapper[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.LLMEngine]: - @functools.wraps(f) - def wrapper(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine: - # wrapped around custom init to provide some meta compression for all decls and attrs - # and add general vllm.LLMEngine for any vllm implementation. - if self.__llm_implementation__ == "vllm": - # TODO: Do some more processing with token_id once we support token streaming - tokenizer_id = self._bentomodel.path if self.tokenizer_id == "local" else self.tokenizer_id - return vllm.LLMEngine.from_engine_args(get_engine_args(self, tokenizer=tokenizer_id)) - else: - (model_decls, model_attrs), _ = self.llm_parameters - return f(self, *(*model_decls, *decls), **{**model_attrs, **attrs}) - return wrapper + @functools.wraps(f) + def wrapper(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine: + if self.__llm_implementation__ == "vllm": + # TODO: Do some more processing with token_id once we support token streaming + tokenizer_id = self._bentomodel.path if self.tokenizer_id == "local" else self.tokenizer_id + return vllm.LLMEngine.from_engine_args(get_engine_args(self, tokenizer=tokenizer_id)) + else: + (model_decls, model_attrs), _ = self.llm_parameters + return f(self, *(*model_decls, *decls), **{**model_attrs, **attrs}) + return wrapper def _wrapped_load_tokenizer(f: _load_tokenizer_wrapper[M, T]) -> t.Callable[[LLM[M, T]], T]: - @functools.wraps(f) - def wrapper(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T: - _, model_tokenizer_attrs = self.llm_parameters - tokenizer_attrs = {**model_tokenizer_attrs, **tokenizer_attrs} - return f(self, **tokenizer_attrs) - return wrapper + @functools.wraps(f) + def wrapper(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T: return f(self, **{**self.llm_parameters[-1], **tokenizer_attrs}) + return wrapper def _wrapped_llm_post_init(f: _llm_post_init_wrapper[M, T]) -> t.Callable[[LLM[M, T]], None]: - @functools.wraps(f) - def wrapper(self: LLM[M, T]) -> None: - _default_post_init(self) - f(self) - return wrapper + @functools.wraps(f) + def wrapper(self: LLM[M, T]) -> None: + _default_post_init(self) + f(self) + return wrapper def _wrapped_save_pretrained(f: _save_pretrained_wrapper[M, T]) -> t.Callable[[LLM[M, T], str | Path], None]: - @functools.wraps(f) - def wrapper(self: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: - if isinstance(save_directory, Path): save_directory = str(save_directory) - if self.__llm_model__ is None: raise RuntimeError("Cannot 'save_pretrained' with unload model instance.") - if self.bettertransformer and self.__llm_implementation__ == "pt": _object_setattr(self, "__llm_model__", t.cast("transformers.PreTrainedModel", self.__llm_model__).reverse_bettertransformer()) - f(self, save_directory, **attrs) - return wrapper + @functools.wraps(f) + def wrapper(self: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: + if isinstance(save_directory, Path): save_directory = str(save_directory) + if self.__llm_model__ is None: raise RuntimeError("Cannot 'save_pretrained' with unload model instance.") + if self.bettertransformer and self.__llm_implementation__ == "pt": _object_setattr(self, "__llm_model__", t.cast("transformers.PreTrainedModel", self.__llm_model__).reverse_bettertransformer()) + f(self, save_directory, **attrs) + return wrapper def _update_docstring(cls: LLM[M, T], fn: str) -> AnyCallable: - # update docstring for given entrypoint - original_fn = getattr(cls, fn, getattr(LLMInterface, fn)) - original_fn.__doc__ = original_fn.__doc__ or f"""\ + # update docstring for given entrypoint + original_fn = getattr(cls, fn, getattr(LLMInterface, fn)) + original_fn.__doc__ = original_fn.__doc__ or f"""\ {cls.__name__}'s implementation for {fn}. Note that if LoRA is enabled (via either SDK or CLI), `self.model` will become a `peft.PeftModel` The original model can then be accessed with 'self.model.get_base_model()'. """ - setattr(cls, fn, original_fn) - return original_fn - + setattr(cls, fn, original_fn) + return original_fn def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], None]: - attributes = { - "import_model": _wrapped_import_model, - "load_model": _wrapped_load_model, - "load_tokenizer": _wrapped_load_tokenizer, - "llm_post_init": _wrapped_llm_post_init, - "save_pretrained": _wrapped_save_pretrained, - } - args: ListStr = [] - anns: DictStrAny = {} - lines: ListStr = [] - globs: DictStrAny = { - "cls": cls, - "_cached_LLMInterface_get": _object_getattribute.__get__(LLMInterface), - "__gen_docstring": _update_docstring, - } - # function initialisation - for func, impl in attributes.items(): - impl_name = f"__wrapped_{func}" - globs.update({f"__serialisation_{func}": getattr(serialisation, func, None), impl_name: impl}) - cached_func_name = f"_cached_{cls.__name__}_func" - if func == "llm_post_init": func_call = f"_impl_{cls.__name__}_{func}={cached_func_name}" - else: func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMInterface_get('{func}') else __serialisation_{func}" - lines.extend( - [ - f"{cached_func_name}=cls.{func}", - func_call, - _setattr_class(func, f"{impl_name}(_impl_{cls.__name__}_{func})"), - ] - ) + attributes = {"import_model": _wrapped_import_model, "load_model": _wrapped_load_model, "load_tokenizer": _wrapped_load_tokenizer, "llm_post_init": _wrapped_llm_post_init, "save_pretrained": _wrapped_save_pretrained} + args: ListStr = [] + anns: DictStrAny = {} + lines: ListStr = [] + globs: DictStrAny = {"cls": cls, "_cached_LLMInterface_get": _object_getattribute.__get__(LLMInterface), "__gen_docstring": _update_docstring} + # function initialisation + for func, impl in attributes.items(): + impl_name = f"__wrapped_{func}" + globs.update({f"__serialisation_{func}": getattr(serialisation, func, None), impl_name: impl}) + cached_func_name = f"_cached_{cls.__name__}_func" + if func == "llm_post_init": func_call = f"_impl_{cls.__name__}_{func}={cached_func_name}" + else: func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMInterface_get('{func}') else __serialisation_{func}" + lines.extend([f"{cached_func_name}=cls.{func}", func_call, _setattr_class(func, f"{impl_name}(_impl_{cls.__name__}_{func})"),]) - # cached attribute initialisation - interface_anns = codegen.get_annotations(LLMInterface) - for v in {"bentomodel", "model", "tokenizer", "adapter_map"}: - lines.append(_setattr_class(f"__llm_{v}__", None)) - anns[f"__llm_{v}__"] = interface_anns.get(f"__llm_{v}__") + # cached attribute initialisation + interface_anns = codegen.get_annotations(LLMInterface) + for v in {"bentomodel", "model", "tokenizer", "adapter_map"}: + lines.append(_setattr_class(f"__llm_{v}__", None)) + anns[f"__llm_{v}__"] = interface_anns.get(f"__llm_{v}__") - # boolean to determine whether LLM has defined an implementation for a function - for fn in {"generate", "generate_one", "generate_iterator", "embeddings"}: - key = f"__llm_supports_{fn}__" - lines.extend( - [ - _setattr_class(key, f"cls.{fn} is not _cached_LLMInterface_get('{fn}')"), - f"__gen_docstring(cls, '{fn}')", - ] - ) - anns[key] = interface_anns.get(key) - return codegen.generate_function(cls, "__assign_llm_attr", lines, args=("cls", *args), globs=globs, annotations=anns) + # boolean to determine whether LLM has defined an implementation for a function + for fn in {"generate", "generate_one", "generate_iterator", "embeddings"}: + key = f"__llm_supports_{fn}__" + lines.extend([_setattr_class(key, f"cls.{fn} is not _cached_LLMInterface_get('{fn}')"), f"__gen_docstring(cls, '{fn}')",]) + anns[key] = interface_anns.get(key) + return codegen.generate_function(cls, "__assign_llm_attr", lines, args=("cls", *args), globs=globs, annotations=anns) _AdaptersTuple: type[AdaptersTuple] = codegen.make_attr_tuple_class("AdaptersTuple", ["adapter_id", "name", "config"]) @attr.define(slots=True, repr=False, init=False) class LLM(LLMInterface[M, T], ReprMixin): - if t.TYPE_CHECKING: __name__: str - config: LLMConfig - """The config instance to use for this LLM. This will be created based on config_class and available + if t.TYPE_CHECKING: __name__: str + config: LLMConfig + """The config instance to use for this LLM. This will be created based on config_class and available when initialising the LLM.""" - quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None - """Quantisation config for quantised model on the fly.""" - _model_id: str - _runtime: t.Literal["ggml", "transformers"] - _model_decls: TupleAny - _model_attrs: DictStrAny - _tokenizer_attrs: DictStrAny - _tag: bentoml.Tag - _adapters_mapping: AdaptersMapping | None - _model_version: str - _quantize_method: t.Literal["int8", "int4", "gptq"] | None - _serialisation_format: t.Literal["safetensors", "legacy"] - @staticmethod - def _infer_implementation_from_name(name: str) -> tuple[LiteralRuntime, str]: - if name.startswith("Flax"): return "flax", name[4:] - elif name.startswith("TF"): return "tf", name[2:] - elif name.startswith("VLLM"): return "vllm", name[4:] - else: return "pt", name - def __init_subclass__(cls: type[LLM[M, T]]) -> None: - cd = cls.__dict__ - implementation, config_class_name = cls._infer_implementation_from_name(cls.__name__) - cls.__llm_implementation__ = implementation - config_class = AutoConfig.infer_class_from_name(config_class_name) - if "__openllm_internal__" in cd: - if "config_class" not in cd: cls.config_class = config_class - elif "config_class" not in cd: raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.") - _make_assignment_script(cls)(cls) - if "tokenizer_id" not in cd and cls.__llm_implementation__ == "vllm": cls.tokenizer_id = _DEFAULT_TOKENIZER + quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None + """Quantisation config for quantised model on the fly.""" + _model_id: str + _runtime: t.Literal["ggml", "transformers"] + _model_decls: TupleAny + _model_attrs: DictStrAny + _tokenizer_attrs: DictStrAny + _tag: bentoml.Tag + _adapters_mapping: AdaptersMapping | None + _model_version: str + _quantize_method: t.Literal["int8", "int4", "gptq"] | None + _serialisation_format: t.Literal["safetensors", "legacy"] + @staticmethod + def _infer_implementation_from_name(name: str) -> tuple[LiteralRuntime, str]: + if name.startswith("Flax"): return "flax", name[4:] + elif name.startswith("TF"): return "tf", name[2:] + elif name.startswith("VLLM"): return "vllm", name[4:] + else: return "pt", name - if implementation == "vllm": - def vllm_postprocess_generate(self: LLM["vllm.LLMEngine", T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str: return generation_result[0]["outputs"][0]["text"] - def vllm_generate(self: LLM["vllm.LLMEngine", T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]: - outputs: list[vllm.RequestOutput] = [] - # TODO: support prompt_token_ids - self.model.add_request(request_id=str(uuid.uuid4().hex), prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config()) - while self.model.has_unfinished_requests(): outputs.extend([r for r in self.model.step() if r.finished]) - return [unmarshal_vllm_outputs(i) for i in outputs] - _object_setattr(cls, "postprocess_generate", vllm_postprocess_generate) - _object_setattr(cls, "generate", vllm_generate) - # fmt: off - @overload - def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["implementation"]) -> LiteralRuntime: ... - @overload - def __getitem__(self, item: t.Literal["model"]) -> M | None: ... - @overload - def __getitem__(self, item: t.Literal["tokenizer"]) -> T | None: ... - @overload - def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ... - @overload - def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ... - @overload - def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["supports_generate"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ... - # fmt: on - def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: - if item is None: raise TypeError(f"{self} doesn't understand how to index None.") - item = inflection.underscore(item) - internal_attributes = f"__llm_{item}__" - if hasattr(self, internal_attributes): return getattr(self, internal_attributes) - elif hasattr(self, item): return getattr(self, item) - else: raise KeyError(item) - @classmethod - @overload - def from_pretrained(cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., quantization_config: transformers.BitsAndBytesConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any) -> LLM[M, T]: ... - @classmethod - @overload - def from_pretrained(cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["gptq"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., quantization_config: autogptq.BaseQuantizeConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any) -> LLM[M, T]: ... - @classmethod - def from_pretrained( - cls, - model_id: str | None = None, - model_version: str | None = None, - llm_config: LLMConfig | None = None, - *args: t.Any, - runtime: t.Literal["ggml", "transformers"] | None = None, - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - bettertransformer: str | bool | None = None, - adapter_id: str | None = None, - adapter_name: str | None = None, - adapter_map: dict[str, str | None] | None = None, - quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None = None, - serialisation: t.Literal["safetensors", "legacy"] = "safetensors", - **attrs: t.Any, - ) -> LLM[M, T]: - """Instantiate a pretrained LLM. + def __init_subclass__(cls: type[LLM[M, T]]) -> None: + cd = cls.__dict__ + implementation, config_class_name = cls._infer_implementation_from_name(cls.__name__) + cls.__llm_implementation__ = implementation + config_class = AutoConfig.infer_class_from_name(config_class_name) + if "__openllm_internal__" in cd: + if "config_class" not in cd: cls.config_class = config_class + elif "config_class" not in cd: raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.") + _make_assignment_script(cls)(cls) + if "tokenizer_id" not in cd and cls.__llm_implementation__ == "vllm": cls.tokenizer_id = _DEFAULT_TOKENIZER - ``LLM.from_pretrained`` follows the same design principle as HuggingFace's `from_pretrained` method, plus the following: + if implementation == "vllm": + def vllm_postprocess_generate(self: LLM["vllm.LLMEngine", T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str: return generation_result[0]["outputs"][0]["text"] + def vllm_generate(self: LLM["vllm.LLMEngine", T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]: + outputs: list[vllm.RequestOutput] = [] + # TODO: support prompt_token_ids + self.model.add_request(request_id=str(uuid.uuid4().hex), prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config()) + while self.model.has_unfinished_requests(): + outputs.extend([r for r in self.model.step() if r.finished]) + return [unmarshal_vllm_outputs(i) for i in outputs] - ### Optimization options: + _object_setattr(cls, "postprocess_generate", vllm_postprocess_generate) + _object_setattr(cls, "generate", vllm_generate) - > This is most notable during serving time. + # fmt: off + @overload + def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["implementation"]) -> LiteralRuntime: ... + @overload + def __getitem__(self, item: t.Literal["model"]) -> M | None: ... + @overload + def __getitem__(self, item: t.Literal["tokenizer"]) -> T | None: ... + @overload + def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ... + @overload + def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ... + @overload + def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["supports_generate"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ... + def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: + if item is None: raise TypeError(f"{self} doesn't understand how to index None.") + item = inflection.underscore(item) + internal_attributes = f"__llm_{item}__" + if hasattr(self, internal_attributes): return getattr(self, internal_attributes) + elif hasattr(self, item): return getattr(self, item) + else: raise KeyError(item) + @classmethod + @overload + def from_pretrained( + cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., + quantization_config: transformers.BitsAndBytesConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any + ) -> LLM[M, T]: ... + @classmethod + @overload + def from_pretrained( + cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["gptq"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., + quantization_config: autogptq.BaseQuantizeConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any + ) -> LLM[M, T]: ... + # fmt: on + @classmethod + def from_pretrained( + cls, model_id: str | None = None, model_version: str | None = None, llm_config: LLMConfig | None = None, *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: str | bool | None = None, adapter_id: str | None = None, adapter_name: str | None = None, + adapter_map: dict[str, str | None] | None = None, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None = None, serialisation: t.Literal["safetensors", "legacy"] = "safetensors", **attrs: t.Any, + ) -> LLM[M, T]: + """Instantiate a pretrained LLM. - - quantize: quantize the model with the given quantization method. Currently supported int8, int4 quantization - - bettertransformer: Apply FasterTransformer to given pretrained weight + ``LLM.from_pretrained`` follows the same design principle as HuggingFace's `from_pretrained` method, plus the following: - > Currently, the above two options are mutually exclusive. + ### Optimization options: - #### Quantisation options + > This is most notable during serving time. - For customising options for quantisation config, ``openllm.LLM`` accepts all arbitrary arguments that is passed to ``transformers.BitsAndBytesConfig`` - plus ``quantize`` value. For example, for ``int8`` quantisation, specify the following: - ```python - model = openllm.AutoLLM.from_pretrained("opt", quantize='int8', llm_int8_enable_fp32_cpu_offload=False) - ``` + - quantize: quantize the model with the given quantization method. Currently supported int8, int4 quantization + - bettertransformer: Apply FasterTransformer to given pretrained weight - For all GPTQ-related options, it accepts all value prefixed with `gptq_*`. The parsed value then could be parsed - to ``auto_gptq.BaseQuantizeConfig``. + > Currently, the above two options are mutually exclusive. - ### Adapter options: + #### Quantisation options - > This is used in conjunction with the fine-tuning features + For customising options for quantisation config, ``openllm.LLM`` accepts all arbitrary arguments that is passed to ``transformers.BitsAndBytesConfig`` + plus ``quantize`` value. For example, for ``int8`` quantisation, specify the following: + ```python + model = openllm.AutoLLM.from_pretrained("opt", quantize='int8', llm_int8_enable_fp32_cpu_offload=False) + ``` - - adapter_id: Optional [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to apply to said model. - - adapter_name: Optional name of the adapter to apply to said model. If not provided, it will be handled internally by OpenLLM. - - adapter_map: optional dictionary of adapter_id to adapter_name. Note that this is mutually exclusive with adapter_id/adapter_name arguments. + For all GPTQ-related options, it accepts all value prefixed with `gptq_*`. The parsed value then could be parsed + to ``auto_gptq.BaseQuantizeConfig``. - Args: - model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used. - > **Warning**: If custom path is passed, make sure it contains all available file to construct - > ``transformers.PretrainedConfig``, ``transformers.PreTrainedModel``, and ``transformers.PreTrainedTokenizer``. - model_name: Optional model name to be saved with this LLM. Default to None. It will be inferred automatically from model_id. - If model_id is a custom path, it will be the basename of the given path. - model_version: Optional version for this given model id. Default to None. This is useful for saving from custom path. - If set to None, the version will either be the git hash from given pretrained model, or the hash inferred - from last modified time of the given directory. - llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM - will use `config_class` to construct default configuration. - quantize: The quantization to use for this LLM. Defaults to None. Possible values - include int8, int4 and gptq. - runtime: Optional runtime to run this LLM. Default to 'transformers'. 'ggml' supports is working in progress. - quantization_config: The quantization config (`transformers.BitsAndBytesConfig` | `autogtpq.BaseQuantizeConfig`) to use. Note that this is mutually exclusive with `quantize` - serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. - Default behaviour is similar to ``safe_serialization=False``. - bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. - adapter_id: The [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to use for this LLM. Defaults to None. - adapter_name: The adapter name to use for this LLM. Defaults to None. - adapter_map: The adapter map to use for this LLM. Defaults to None. Note that this is mutually exclusive with adapter_id/adapter_name arguments. - *args: The args to be passed to the model. - **attrs: The kwargs to be passed to the model. - """ - cfg_cls = cls.config_class - model_id = first_not_none(model_id, cfg_cls.__openllm_env__["model_id_value"], default=cfg_cls.__openllm_default_id__) - if validate_is_path(model_id): model_id = resolve_filepath(model_id) - quantize = first_not_none(quantize, cfg_cls.__openllm_env__["quantize_value"], default=None) + ### Adapter options: - # quantization setup - if quantization_config and quantize: raise ValueError("'quantization_config' and 'quantize' are mutually exclusive. Either customise your quantization_config or use the 'quantize' argument.") - if quantization_config is None and quantize is not None: quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs) - # We will use safetensors for gptq - if quantize == "gptq": serialisation = "safetensors" - # We will use legacy format for vllm - elif cls.__llm_implementation__ == "vllm": serialisation = "legacy" + > This is used in conjunction with the fine-tuning features - # NOTE: LoRA adapter setup - if adapter_map and adapter_id: raise ValueError("'adapter_map' and 'adapter_id' are mutually exclusive. Either provide a 'adapter_map' ({adapter_id: adapter_name | None, ...}) or use the combination of adapter_id/adapter_name arguments. ") - if adapter_map is None and adapter_id is not None: adapter_map = {adapter_id: adapter_name} - if adapter_map is not None and not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'") - if adapter_map: logger.debug("OpenLLM will apply the following adapters layers: %s", list(adapter_map)) + - adapter_id: Optional [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to apply to said model. + - adapter_name: Optional name of the adapter to apply to said model. If not provided, it will be handled internally by OpenLLM. + - adapter_map: optional dictionary of adapter_id to adapter_name. Note that this is mutually exclusive with adapter_id/adapter_name arguments. - if llm_config is None: - llm_config = cls.config_class.model_construct_env(**attrs) - # The rests of the kwargs that is not used by the config class should be stored into __openllm_extras__. - attrs = llm_config["extras"] + Args: + model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used. + > **Warning**: If custom path is passed, make sure it contains all available file to construct + > ``transformers.PretrainedConfig``, ``transformers.PreTrainedModel``, and ``transformers.PreTrainedTokenizer``. + model_name: Optional model name to be saved with this LLM. Default to None. It will be inferred automatically from model_id. + If model_id is a custom path, it will be the basename of the given path. + model_version: Optional version for this given model id. Default to None. This is useful for saving from custom path. + If set to None, the version will either be the git hash from given pretrained model, or the hash inferred + from last modified time of the given directory. + llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM + will use `config_class` to construct default configuration. + quantize: The quantization to use for this LLM. Defaults to None. Possible values + include int8, int4 and gptq. + runtime: Optional runtime to run this LLM. Default to 'transformers'. 'ggml' supports is working in progress. + quantization_config: The quantization config (`transformers.BitsAndBytesConfig` | `autogtpq.BaseQuantizeConfig`) to use. Note that this is mutually exclusive with `quantize` + serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. + Default behaviour is similar to ``safe_serialization=False``. + bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. + adapter_id: The [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to use for this LLM. Defaults to None. + adapter_name: The adapter name to use for this LLM. Defaults to None. + adapter_map: The adapter map to use for this LLM. Defaults to None. Note that this is mutually exclusive with adapter_id/adapter_name arguments. + *args: The args to be passed to the model. + **attrs: The kwargs to be passed to the model. + """ + cfg_cls = cls.config_class + model_id = first_not_none(model_id, cfg_cls.__openllm_env__["model_id_value"], cfg_cls.__openllm_default_id__) + if model_id is None: raise RuntimeError("Failed to resolve a valid model_id.") + if validate_is_path(model_id): model_id = resolve_filepath(model_id) + quantize = first_not_none(quantize, cfg_cls.__openllm_env__["quantize_value"], default=None) - try: - _tag = cls.generate_tag(model_id, model_version) - if _tag.version is None: raise ValueError(f"Failed to resolve the correct model version for {cfg_cls.__openllm_start_name__}") - except Exception as err: - raise OpenLLMException(f"Failed to generate a valid tag for {cfg_cls.__openllm_start_name__} with 'model_id={model_id}' (lookup to see its traceback):\n{err}") from err + # quantization setup + if quantization_config and quantize: raise ValueError("'quantization_config' and 'quantize' are mutually exclusive. Either customise your quantization_config or use the 'quantize' argument.") + if quantization_config is None and quantize is not None: quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs) + if quantize == "gptq": serialisation = "safetensors" + elif cls.__llm_implementation__ == "vllm": serialisation = "legacy" # Currently working-in-progress - return cls( - *args, - model_id=model_id, - llm_config=llm_config, - quantization_config=quantization_config, - bettertransformer=str(first_not_none(bettertransformer, cfg_cls.__openllm_env__["bettertransformer_value"], default=None)).upper() in ENV_VARS_TRUE_VALUES, - _runtime=first_not_none(runtime, cfg_cls.__openllm_env__["runtime_value"], default=cfg_cls.__openllm_runtime__), - _adapters_mapping=resolve_peft_config_type(adapter_map) if adapter_map is not None else None, - _quantize_method=quantize, - _model_version=_tag.version, - _tag=_tag, - _serialisation_format=serialisation, - **attrs, - ) - @classmethod - @functools.lru_cache - @apply(str.lower) - def _generate_tag_str(cls, model_id: str, model_version: str | None) -> str: - """Generate a compliant ``bentoml.Tag`` from model_id. + # NOTE: LoRA adapter setup + if adapter_map and adapter_id: raise ValueError("'adapter_map' and 'adapter_id' are mutually exclusive. Either provide a 'adapter_map' ({adapter_id: adapter_name | None, ...}) or use the combination of adapter_id/adapter_name arguments. ") + if adapter_map is None and adapter_id is not None: adapter_map = {adapter_id: adapter_name} + if adapter_map is not None and not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'") + if adapter_map: logger.debug("OpenLLM will apply the following adapters layers: %s", list(adapter_map)) - If model_id is a pretrained_id from HF, then it will have the following format: -: - If model_id contains the revision itself, then the same format above - If model_id is a path, then it will be -: if model_version is not passesd, otherwise -: + if llm_config is None: + llm_config = cls.config_class.model_construct_env(**attrs) + # The rests of the kwargs that is not used by the config class should be stored into __openllm_extras__. + attrs = llm_config["extras"] - **Note** here that the generated SHA1 for path cases is that it will be based on last modified time. + try: + _tag = cls.generate_tag(model_id, model_version) + if _tag.version is None: raise ValueError(f"Failed to resolve the correct model version for {cfg_cls.__openllm_start_name__}") + except Exception as err: raise OpenLLMException(f"Failed to generate a valid tag for {cfg_cls.__openllm_start_name__} with 'model_id={model_id}' (lookup to see its traceback):\n{err}") from err - Args: - model_id: Model id for this given LLM. It can be pretrained weights URL, custom path. - model_version: Specific revision for this model_id or custom version. + return cls( + *args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config, bettertransformer=str(first_not_none(bettertransformer, cfg_cls.__openllm_env__["bettertransformer_value"], default=None)).upper() in ENV_VARS_TRUE_VALUES, _runtime=first_not_none(runtime, cfg_cls.__openllm_env__["runtime_value"], default=cfg_cls.__openllm_runtime__), + _adapters_mapping=resolve_peft_config_type(adapter_map) if adapter_map is not None else None, _quantize_method=quantize, _model_version=_tag.version, _tag=_tag, _serialisation_format=serialisation, **attrs + ) - Returns: - ``str``: Generated tag format that can be parsed by ``bentoml.Tag`` - """ - # specific branch for running in docker, this is very hacky, needs change upstream - if in_docker() and os.getenv("BENTO_PATH") is not None: return ":".join(fs.path.parts(model_id)[-2:]) + @classmethod + @functools.lru_cache + @apply(str.lower) + def _generate_tag_str(cls, model_id: str, model_version: str | None) -> str: + """Generate a compliant ``bentoml.Tag`` from model_id. - model_name = normalise_model_name(model_id) - model_id, *maybe_revision = model_id.rsplit(":") - if len(maybe_revision) > 0: - if model_version is not None: logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.", maybe_revision[0], model_version) - return f"{cls.__llm_implementation__}-{model_name}:{maybe_revision[0]}" + If model_id is a pretrained_id from HF, then it will have the following format: -: + If model_id contains the revision itself, then the same format above + If model_id is a path, then it will be -: if model_version is not passesd, otherwise -: - tag_name = f"{cls.__llm_implementation__}-{model_name}" - if os.getenv("OPENLLM_USE_LOCAL_LATEST", str(False)).upper() in ENV_VARS_TRUE_VALUES: return bentoml_cattr.unstructure(bentoml.models.get(f"{tag_name}{':'+model_version if model_version is not None else ''}").tag) - if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id)) - else: - _config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=cls.config_class.__openllm_trust_remote_code__, revision=first_not_none(model_version, default="main")) - model_version = getattr(_config, "_commit_hash", None) - if model_version is None: raise ValueError(f"Internal errors when parsing config for pretrained '{model_id}' ('commit_hash' not found)") - return f"{tag_name}:{model_version}" - @classmethod - def generate_tag(cls, *param_decls: t.Any, **attrs: t.Any) -> bentoml.Tag: return bentoml.Tag.from_taglike(cls._generate_tag_str(*param_decls, **attrs)) - def __init__( + **Note** here that the generated SHA1 for path cases is that it will be based on last modified time. + + Args: + model_id: Model id for this given LLM. It can be pretrained weights URL, custom path. + model_version: Specific revision for this model_id or custom version. + + Returns: + ``str``: Generated tag format that can be parsed by ``bentoml.Tag`` + """ + # specific branch for running in docker, this is very hacky, needs change upstream + if in_docker() and os.getenv("BENTO_PATH") is not None: return ":".join(fs.path.parts(model_id)[-2:]) + + model_name = normalise_model_name(model_id) + model_id, *maybe_revision = model_id.rsplit(":") + if len(maybe_revision) > 0: + if model_version is not None: logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.", maybe_revision[0], model_version) + return f"{cls.__llm_implementation__}-{model_name}:{maybe_revision[0]}" + + tag_name = f"{cls.__llm_implementation__}-{model_name}" + if os.getenv("OPENLLM_USE_LOCAL_LATEST", str(False)).upper() in ENV_VARS_TRUE_VALUES: return bentoml_cattr.unstructure(bentoml.models.get(f"{tag_name}{':'+model_version if model_version is not None else ''}").tag) + if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id)) + else: + _config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=cls.config_class.__openllm_trust_remote_code__, revision=first_not_none(model_version, default="main")) + model_version = getattr(_config, "_commit_hash", None) + if model_version is None: raise ValueError(f"Internal errors when parsing config for pretrained '{model_id}' ('commit_hash' not found)") + return f"{tag_name}:{model_version}" + + @classmethod + def generate_tag(cls, *param_decls: t.Any, **attrs: t.Any) -> bentoml.Tag: return bentoml.Tag.from_taglike(cls._generate_tag_str(*param_decls, **attrs)) + def __init__( + self, *args: t.Any, model_id: str, llm_config: LLMConfig, bettertransformer: bool | None, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, _adapters_mapping: AdaptersMapping | None, _tag: bentoml.Tag, _quantize_method: t.Literal["int8", "int4", "gptq"] | None, + _runtime: t.Literal["ggml", "transformers"], _model_version: str, _serialisation_format: t.Literal["safetensors", "legacy"], **attrs: t.Any, + ): + """Initialize the LLM with given pretrained model. + + > **Warning** + > To initializing any LLM, you should use `openllm.AutoLLM` or `openllm.LLM.from_pretrained` instead. + > `__init__` initialization is only for internal use. + + Note: + - *args to be passed to the model. + - **attrs will first be parsed to the AutoConfig, then the rest will be parsed to the import_model + - for tokenizer kwargs, it should be prefixed with _tokenizer_* + + For custom pretrained path, it is recommended to pass in 'model_version' alongside with the path + to ensure that it won't be loaded multiple times. + Internally, if a pretrained is given as a HuggingFace repository path , OpenLLM will usethe commit_hash + to generate the model version. + + For better consistency, we recommend users to also push the fine-tuned model to HuggingFace repository. + + If you need to overwrite the default ``import_model``, implement the following in your subclass: + + ```python + def import_model( self, *args: t.Any, - model_id: str, - llm_config: LLMConfig, - bettertransformer: bool | None, - quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, - _adapters_mapping: AdaptersMapping | None, - _tag: bentoml.Tag, - _quantize_method: t.Literal["int8", "int4", "gptq"] | None, - _runtime: t.Literal["ggml", "transformers"], - _model_version: str, - _serialisation_format: t.Literal["safetensors", "legacy"], + trust_remote_code: bool, **attrs: t.Any, ): - """Initialize the LLM with given pretrained model. + _, tokenizer_attrs = self.llm_parameters - > **Warning** - > To initializing any LLM, you should use `openllm.AutoLLM` or `openllm.LLM.from_pretrained` instead. - > `__init__` initialization is only for internal use. - - Note: - - *args to be passed to the model. - - **attrs will first be parsed to the AutoConfig, then the rest will be parsed to the import_model - - for tokenizer kwargs, it should be prefixed with _tokenizer_* - - For custom pretrained path, it is recommended to pass in 'model_version' alongside with the path - to ensure that it won't be loaded multiple times. - Internally, if a pretrained is given as a HuggingFace repository path , OpenLLM will usethe commit_hash - to generate the model version. - - For better consistency, we recommend users to also push the fine-tuned model to HuggingFace repository. - - If you need to overwrite the default ``import_model``, implement the following in your subclass: - - ```python - def import_model( - self, - *args: t.Any, - trust_remote_code: bool, - **attrs: t.Any, - ): - _, tokenizer_attrs = self.llm_parameters - - return bentoml.transformers.save_model( - tag, - transformers.AutoModelForCausalLM.from_pretrained( - self.model_id, device_map="auto", torch_dtype=torch.bfloat16, **attrs - ), - custom_objects={ - "tokenizer": transformers.AutoTokenizer.from_pretrained( - self.model_id, padding_size="left", **tokenizer_attrs - ) - }, - ) - ``` - - If your import model doesn't require customization, you can simply pass in `import_kwargs` - at class level that will be then passed into The default `import_model` implementation. - See ``openllm.DollyV2`` for example. - - ```python - dolly_v2_runner = openllm.Runner( - "dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map="cuda" + return bentoml.transformers.save_model( + tag, + transformers.AutoModelForCausalLM.from_pretrained( + self.model_id, device_map="auto", torch_dtype=torch.bfloat16, **attrs + ), + custom_objects={ + "tokenizer": transformers.AutoTokenizer.from_pretrained( + self.model_id, padding_size="left", **tokenizer_attrs + ) + }, ) - ``` + ``` - Note: If you implement your own `import_model`, then `import_kwargs` will be the - base kwargs. You can still override those via ``openllm.Runner``. + If your import model doesn't require customization, you can simply pass in `import_kwargs` + at class level that will be then passed into The default `import_model` implementation. + See ``openllm.DollyV2`` for example. - Note that this tag will be generated based on `self.default_id`. - passed from the __init__ constructor. + ```python + dolly_v2_runner = openllm.Runner( + "dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map="cuda" + ) + ``` - ``llm_post_init`` can also be implemented if you need to do any additional - initialization after everything is setup. + Note: If you implement your own `import_model`, then `import_kwargs` will be the + base kwargs. You can still override those via ``openllm.Runner``. - Note: If you need to implement a custom `load_model`, the following is an example from Falcon implementation: + Note that this tag will be generated based on `self.default_id`. + passed from the __init__ constructor. - ```python - def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: - torch_dtype = attrs.pop("torch_dtype", torch.bfloat16) - device_map = attrs.pop("device_map", "auto") + ``llm_post_init`` can also be implemented if you need to do any additional + initialization after everything is setup. - _ref = bentoml.transformers.get(tag) + Note: If you need to implement a custom `load_model`, the following is an example from Falcon implementation: - model = bentoml.transformers.load_model(_ref, device_map=device_map, torch_dtype=torch_dtype, **attrs) - return transformers.pipeline("text-generation", model=model, tokenizer=_ref.custom_objects["tokenizer"]) - ``` + ```python + def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: + torch_dtype = attrs.pop("torch_dtype", torch.bfloat16) + device_map = attrs.pop("device_map", "auto") - Args: - model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used. - llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM - will use `config_class` to construct default configuration. - bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. - quantization_config: ``transformers.BitsAndBytesConfig`` configuration, or 'gptq' denoting this model to be loaded with GPTQ. - *args: The args to be passed to the model. - **attrs: The kwargs to be passed to the model. - """ - # low_cpu_mem_usage is only available for model - # this is helpful on system with low memory to avoid OOM - low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True) - if self.__llm_implementation__ == "pt": attrs.update({"low_cpu_mem_usage": low_cpu_mem_usage, "quantization_config": quantization_config}) - model_kwds: DictStrAny = {} - tokenizer_kwds: DictStrAny = {} - if self.import_kwargs is not None: model_kwds, tokenizer_kwds = self.import_kwargs - # parsing tokenizer and model kwargs, as the hierachy is param pass > default - normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs) - # NOTE: Save the args and kwargs for latter load - self.__attrs_init__(llm_config, quantization_config, model_id, _runtime, args, {**model_kwds, **normalized_model_kwds}, {**tokenizer_kwds, **normalized_tokenizer_kwds}, _tag, _adapters_mapping, _model_version, _quantize_method, _serialisation_format) - # handle trust_remote_code - self.__llm_trust_remote_code__ = self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"]) + _ref = bentoml.transformers.get(tag) - self.llm_post_init() - # we set it here so that we allow subclass to overwrite bettertransformer in llm_post_init - if bettertransformer is True: self.bettertransformer = bettertransformer - else: non_intrusive_setattr(self, "bettertransformer", self.config["bettertransformer"]) - # If lora is passed, the disable bettertransformer - if _adapters_mapping and self.bettertransformer is True: self.bettertransformer = False - def __setattr__(self, attr: str, value: t.Any) -> None: - if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.") - super().__setattr__(attr, value) - @property - def adapters_mapping(self) -> AdaptersMapping | None: return self._adapters_mapping - @adapters_mapping.setter - def adapters_mapping(self, value: AdaptersMapping) -> None: self._adapters_mapping = value - @property - def __repr_keys__(self) -> set[str]: return {"model_id", "runner_name", "config", "adapters_mapping", "runtime", "tag"} - def __repr_args__(self) -> ReprArgs: - for k in self.__repr_keys__: - if k == "config": yield k, self.config.model_dump(flatten=True) - else: yield k, getattr(self, k) - @property - def model_id(self) -> str: return self._model_id - @property - def runtime(self) -> t.Literal["ggml", "transformers"]: return self._runtime - @property - def runner_name(self) -> str: return f"llm-{self.config['start_name']}-runner" - # NOTE: The section below defines a loose contract with langchain's LLM interface. - @property - def llm_type(self) -> str: return normalise_model_name(self._model_id) - @property - def identifying_params(self) -> DictStrAny: return {"configuration": self.config.model_dump_json().decode(), "model_ids": orjson.dumps(self.config["model_ids"]).decode()} - # llm_parameters are used to store args and attrs for model and tokenizer, as it returns a tuple[tuple[model_args, model_kwargs], tokenizer_kwargs] - @property - def llm_parameters(self) -> tuple[tuple[tuple[t.Any, ...], DictStrAny], DictStrAny]: return (self._model_decls, self._model_attrs), self._tokenizer_attrs - @property - def tag(self) -> bentoml.Tag: return self._tag - # ensure_model_id_exists can be called to save the model to local store - def ensure_model_id_exists(self) -> bentoml.Model: return import_model(self.config["start_name"], model_id=self.model_id, model_version=self._model_version, runtime=self.runtime, implementation=self.__llm_implementation__, quantize=self._quantize_method, serialisation_format=self._serialisation_format) - @property - def _bentomodel(self) -> bentoml.Model: - if self.__llm_bentomodel__ is None: self.__llm_bentomodel__ = serialisation.get(self) - return self.__llm_bentomodel__ - @property - def model(self) -> M: - # Run check for GPU - if self.config["requires_gpu"] and device_count() < 1: raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None - # NOTE: the signature of load_model here is the wrapper under _wrapped_load_model - if self.__llm_model__ is None: self.__llm_model__ = self.load_model(*self._model_decls, **self._model_attrs) - return self.__llm_model__ - @property - def tokenizer(self) -> T: - # NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer - if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = self.load_tokenizer(**self._tokenizer_attrs) - return self.__llm_tokenizer__ - def _default_ft_config(self, _adapter_type: AdapterType, inference_mode: bool) -> FineTuneConfig: - strategy = first_not_none(self.config["fine_tune_strategies"].get(_adapter_type), default=FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), llm_config_class=self.config_class)) - return strategy.eval() if inference_mode else strategy.train() - def _transpose_adapter_mapping( self, inference_mode: bool = True, use_cache: bool = True) -> ResolvedAdaptersMapping: - if self._adapters_mapping is None: raise ValueError("LoRA mapping is not set up correctly.") - # early out if we already serialized everything. - if use_cache and self.__llm_adapter_map__ is not None: return self.__llm_adapter_map__ - if not use_cache: logger.debug("Adapter mapping resolution will not be cached. This should only be used during training.") - adapter_map: ResolvedAdaptersMapping = {k: {} for k in self._adapters_mapping} - # this is a temporary check to accept the first option name as 'default' - # then we will raise Error when the optional_name is set to None in next iteration. - _converted_first_none = False - for _adapter_type, _adapters_tuples in self._adapters_mapping.items(): - default_config = self._default_ft_config(_adapter_type, inference_mode) - for adapter in _adapters_tuples: - if not adapter.name and _converted_first_none: raise ValueError(f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {adapter.adapter_id, adapter.config}") - name = adapter.name - if name is None: - _converted_first_none = True - name = "default" - peft_config = default_config.with_config(**adapter.config).to_peft_config() if name == "default" else FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), adapter_config=adapter.config, inference_mode=inference_mode, llm_config_class=self.config_class).to_peft_config() - adapter_map[_adapter_type][name] = (peft_config, adapter.adapter_id) - if self.__llm_adapter_map__ is None and use_cache: self.__llm_adapter_map__ = adapter_map - return adapter_map - @requires_dependencies("peft", extra="fine-tune") - def prepare_for_training(self, adapter_type: AdapterType = "lora", use_gradient_checkpointing: bool = True, **attrs: t.Any) -> tuple[peft.PeftModel, T]: - from peft import prepare_model_for_kbit_training - peft_config = self.config["fine_tune_strategies"].get(adapter_type, FineTuneConfig(adapter_type=t.cast("PeftType", adapter_type), llm_config_class=self.config_class)).train().with_config(**attrs).to_peft_config() - wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config) - if DEBUG: wrapped_peft.print_trainable_parameters() - return wrapped_peft, self.tokenizer - @requires_dependencies("peft", extra="fine-tune") - def apply_adapter(self, inference_mode: bool = True, adapter_type: AdapterType = "lora", load_adapters: t.Literal["all"] | list[str] | None = None, use_cache: bool = True) -> M: - """Apply given LoRA mapping to the model. + model = bentoml.transformers.load_model(_ref, device_map=device_map, torch_dtype=torch_dtype, **attrs) + return transformers.pipeline("text-generation", model=model, tokenizer=_ref.custom_objects["tokenizer"]) + ``` - Note that the base model can still be accessed via self.model.get_base_model(). - """ - if self.__llm_model__ is None: raise ValueError("Error: Model is not loaded correctly") - # early out if _adapters_mapping is empty or it is already wrapped with peft. - if not self._adapters_mapping: return self.__llm_model__ - if isinstance(self.__llm_model__, peft.PeftModel): return self.__llm_model__ + Args: + model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used. + llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM + will use `config_class` to construct default configuration. + bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. + quantization_config: ``transformers.BitsAndBytesConfig`` configuration, or 'gptq' denoting this model to be loaded with GPTQ. + *args: The args to be passed to the model. + **attrs: The kwargs to be passed to the model. + """ + # low_cpu_mem_usage is only available for model + # this is helpful on system with low memory to avoid OOM + low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True) + if self.__llm_implementation__ == "pt": attrs.update({"low_cpu_mem_usage": low_cpu_mem_usage, "quantization_config": quantization_config}) + model_kwds: DictStrAny = {} + tokenizer_kwds: DictStrAny = {} + if self.import_kwargs is not None: model_kwds, tokenizer_kwds = self.import_kwargs + # parsing tokenizer and model kwargs, as the hierachy is param pass > default + normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs) + # NOTE: Save the args and kwargs for latter load + self.__attrs_init__(llm_config, quantization_config, model_id, _runtime, args, {**model_kwds, **normalized_model_kwds}, {**tokenizer_kwds, **normalized_tokenizer_kwds}, _tag, _adapters_mapping, _model_version, _quantize_method, _serialisation_format) + # handle trust_remote_code + self.__llm_trust_remote_code__ = self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"]) - _mapping = self._transpose_adapter_mapping(inference_mode=inference_mode, use_cache=use_cache) - if adapter_type not in _mapping: raise ValueError(f"Given adapter type {adapter_type} is not supported. Please choose from {list(_mapping.keys())}") - adapter_mapping = _mapping[adapter_type] + self.llm_post_init() + # we set it here so that we allow subclass to overwrite bettertransformer in llm_post_init + if bettertransformer is True: self.bettertransformer = bettertransformer + else: non_intrusive_setattr(self, "bettertransformer", self.config["bettertransformer"]) + # If lora is passed, the disable bettertransformer + if _adapters_mapping and self.bettertransformer is True: self.bettertransformer = False - self.__llm_model__ = self._wrap_default_peft_model(adapter_mapping, inference_mode=inference_mode) + def __setattr__(self, attr: str, value: t.Any) -> None: + if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.") + super().__setattr__(attr, value) - # now we loop through the rest with add_adapter - if len(adapter_mapping) > 0: - for adapter_name, (_peft_config, _) in adapter_mapping.items(): self.__llm_model__.add_adapter(adapter_name, _peft_config) + @property + def adapters_mapping(self) -> AdaptersMapping | None: return self._adapters_mapping + @adapters_mapping.setter + def adapters_mapping(self, value: AdaptersMapping) -> None: self._adapters_mapping = value - # optionally load adapters. In case of multiple adapters, or on Runner, - # we will need to set load_adapters='all' - if load_adapters is not None: - adapters_to_load = adapter_mapping.keys() if load_adapters == "all" else load_adapters - for adapter_name in adapters_to_load: - _peft_config, _peft_model_id = adapter_mapping[adapter_name] - self.__llm_model__.load_adapter(_peft_model_id, adapter_name=adapter_name, is_trainable=not inference_mode, **dict(_peft_config.to_dict())) + @property + def __repr_keys__(self) -> set[str]: return {"model_id", "runner_name", "config", "adapters_mapping", "runtime", "tag"} + def __repr_args__(self) -> ReprArgs: + for k in self.__repr_keys__: + if k == "config": yield k, self.config.model_dump(flatten=True) + else: yield k, getattr(self, k) - return self.__llm_model__ - def _wrap_default_peft_model(self, adapter_mapping: dict[str, tuple[peft.PeftConfig, str]], inference_mode: bool) -> M: - if self.__llm_model__ is None: raise ValueError("Error: Model is not loaded correctly") - if isinstance(self.__llm_model__, peft.PeftModel): return self.__llm_model__ - if not isinstance(self.__llm_model__, transformers.PreTrainedModel): raise ValueError("Loading LoRA layers currently only runs on PyTorch models.") + @property + def model_id(self) -> str: return self._model_id + @property + def runtime(self) -> t.Literal["ggml", "transformers"]: return self._runtime + @property + def runner_name(self) -> str: return f"llm-{self.config['start_name']}-runner" - if "default" not in adapter_mapping: raise ValueError("There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team.") - default_config, peft_model_id = adapter_mapping.pop("default") + # NOTE: The section below defines a loose contract with langchain's LLM interface. + @property + def llm_type(self) -> str: return normalise_model_name(self._model_id) + @property + def identifying_params(self) -> DictStrAny: return {"configuration": self.config.model_dump_json().decode(), "model_ids": orjson.dumps(self.config["model_ids"]).decode()} + @property + def llm_parameters(self) -> tuple[tuple[tuple[t.Any, ...], DictStrAny], DictStrAny]: return (self._model_decls, self._model_attrs), self._tokenizer_attrs - # the below shared similar logics with `get_peft_model` - # TODO: Support PromptLearningConfig - if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(default_config, peft.PromptLearningConfig): - logger.debug("Given task type '%s' is not supported by peft. Make sure the adapter is loaded manually before running inference.", default_config.task_type) - model = peft.PeftModel(self.__llm_model__, default_config) - else: - # XXX: this is not ideal to serialize like this, maybe for fine-tune we will only support 0.4.0 - # onwards. For now, keep this logic here. - peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type] - if default_config.base_model_name_or_path: - kwargs: DictStrAny = {"is_trainable": not inference_mode} - if "config" in inspect.signature(peft_class.from_pretrained).parameters: kwargs["config"] = default_config - else: kwargs.update(dict(default_config.to_dict().items())) - # BUG: This hits during inference, need fixing - model = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs) - else: model = peft_class(self.__llm_model__, default_config) # in this case, the given base_model_name_or_path is None. This will be hit during training - return model - # order of these fields matter here, make sure to sync it with - # openllm.models.auto.factory.BaseAutoLLMClass.for_model - def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, scheduling_strategy: type[bentoml.Strategy] | None = None) -> LLMRunner[M, T]: - """Convert this LLM into a Runner. + @property + def tag(self) -> bentoml.Tag: return self._tag + # ensure_model_id_exists can be called to save the model to local store + def ensure_model_id_exists(self) -> bentoml.Model: return import_model(self.config["start_name"], model_id=self.model_id, model_version=self._model_version, runtime=self.runtime, implementation=self.__llm_implementation__, quantize=self._quantize_method, serialisation_format=self._serialisation_format) + @property + def _bentomodel(self) -> bentoml.Model: + if self.__llm_bentomodel__ is None: self.__llm_bentomodel__ = serialisation.get(self) + return self.__llm_bentomodel__ - Args: - models: Any additional ``bentoml.Model`` to be included in this given models. - By default, this will be determined from the model_name. - max_batch_size: The maximum batch size for the runner. - max_latency_ms: The maximum latency for the runner. - strategy: The strategy to use for this runner. - embedded: Whether to run this runner in embedded mode. - scheduling_strategy: Whether to create a custom scheduling strategy for this Runner. + @property + def model(self) -> M: + # Run check for GPU + if self.config["requires_gpu"] and device_count() < 1: raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None + # NOTE: the signature of load_model here is the wrapper under _wrapped_load_model + if self.__llm_model__ is None: self.__llm_model__ = self.load_model(*self._model_decls, **self._model_attrs) + return self.__llm_model__ - Returns: - A generated LLMRunner for this LLM. + @property + def tokenizer(self) -> T: + # NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer + if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = self.load_tokenizer(**self._tokenizer_attrs) + return self.__llm_tokenizer__ - NOTE: There are some difference between bentoml.models.get().to_runner() and LLM.to_runner(): 'name'. - - 'name': will be generated by OpenLLM, hence users don't shouldn't worry about this. - The generated name will be 'llm--runner' (ex: llm-dolly-v2-runner, llm-chatglm-runner) - - 'embedded': Will be disabled by default. There is no reason to run LLM in embedded mode. - - 'method_configs': The method configs for the runner will be managed internally by OpenLLM. - """ - models = models if models is not None else [] + def _default_ft_config(self, _adapter_type: AdapterType, inference_mode: bool) -> FineTuneConfig: + strategy = first_not_none(self.config["fine_tune_strategies"].get(_adapter_type), default=FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), llm_config_class=self.config_class)) + return strategy.eval() if inference_mode else strategy.train() - try: models.append(self._bentomodel) - except bentoml.exceptions.NotFound: models.append(serialisation.get(self, auto_import=True)) + def _transpose_adapter_mapping(self, inference_mode: bool = True, use_cache: bool = True) -> ResolvedAdaptersMapping: + if self._adapters_mapping is None: raise ValueError("LoRA mapping is not set up correctly.") + # early out if we already serialized everything. + if use_cache and self.__llm_adapter_map__ is not None: return self.__llm_adapter_map__ + if not use_cache: logger.debug("Adapter mapping resolution will not be cached. This should only be used during training.") + adapter_map: ResolvedAdaptersMapping = {k: {} for k in self._adapters_mapping} + # this is a temporary check to accept the first option name as 'default' + # then we will raise Error when the optional_name is set to None in next iteration. + _converted_first_none = False + for _adapter_type, _adapters_tuples in self._adapters_mapping.items(): + default_config = self._default_ft_config(_adapter_type, inference_mode) + for adapter in _adapters_tuples: + if not adapter.name and _converted_first_none: raise ValueError(f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {adapter.adapter_id, adapter.config}") + name = adapter.name + if name is None: + _converted_first_none = True + name = "default" + peft_config = default_config.with_config(**adapter.config).to_peft_config() if name == "default" else FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), adapter_config=adapter.config, inference_mode=inference_mode, llm_config_class=self.config_class).to_peft_config() + adapter_map[_adapter_type][name] = (peft_config, adapter.adapter_id) + if self.__llm_adapter_map__ is None and use_cache: self.__llm_adapter_map__ = adapter_map + return adapter_map - if scheduling_strategy is None: - from ._strategies import CascadingResourceStrategy - scheduling_strategy = CascadingResourceStrategy + @requires_dependencies("peft", extra="fine-tune") + def prepare_for_training(self, adapter_type: AdapterType = "lora", use_gradient_checkpointing: bool = True, **attrs: t.Any) -> tuple[peft.PeftModel, T]: + from peft import prepare_model_for_kbit_training + peft_config = self.config["fine_tune_strategies"].get(adapter_type, FineTuneConfig(adapter_type=t.cast("PeftType", adapter_type), llm_config_class=self.config_class)).train().with_config(**attrs).to_peft_config() + wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config) + if DEBUG: wrapped_peft.print_trainable_parameters() + return wrapped_peft, self.tokenizer - generate_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False))) - embeddings_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False))) - generate_iterator_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=True))) + @requires_dependencies("peft", extra="fine-tune") + def apply_adapter(self, inference_mode: bool = True, adapter_type: AdapterType = "lora", load_adapters: t.Literal["all"] | list[str] | None = None, use_cache: bool = True) -> M: + """Apply given LoRA mapping to the model. Note that the base model can still be accessed via self.model.get_base_model().""" + if self.__llm_model__ is None: raise ValueError("Error: Model is not loaded correctly") + # early out if _adapters_mapping is empty or it is already wrapped with peft. + if not self._adapters_mapping: return self.__llm_model__ + if isinstance(self.__llm_model__, peft.PeftModel): return self.__llm_model__ - # NOTE: returning the two langchain API's to the runner - return llm_runner_class(self)( - llm_runnable_class(self, embeddings_sig, generate_sig, generate_iterator_sig), - name=self.runner_name, - embedded=False, - models=models, - max_batch_size=max_batch_size, - max_latency_ms=max_latency_ms, - method_configs=bentoml_cattr.unstructure({"embeddings": embeddings_sig, "__call__": generate_sig, "generate": generate_sig, "generate_one": generate_sig, "generate_iterator": generate_iterator_sig}), - scheduling_strategy=scheduling_strategy, - ) - def predict(self, prompt: str, **attrs: t.Any) -> t.Any: - """The scikit-compatible API for self(...).""" - return self.__call__(prompt, **attrs) - def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: - """Returns the generation result and format the result. + _mapping = self._transpose_adapter_mapping(inference_mode=inference_mode, use_cache=use_cache) + if adapter_type not in _mapping: raise ValueError(f"Given adapter type {adapter_type} is not supported. Please choose from {list(_mapping.keys())}") + adapter_mapping = _mapping[adapter_type] - First, it runs `self.sanitize_parameters` to sanitize the parameters. - The the sanitized prompt and kwargs will be pass into self.generate. - Finally, run self.postprocess_generate to postprocess the generated result. + self.__llm_model__ = self._wrap_default_peft_model(adapter_mapping, inference_mode=inference_mode) - This allows users to do the following: + # now we loop through the rest with add_adapter + if len(adapter_mapping) > 0: + for adapter_name, (_peft_config, _) in adapter_mapping.items(): self.__llm_model__.add_adapter(adapter_name, _peft_config) - ```python - llm = openllm.AutoLLM.for_model("dolly-v2") - llm("What is the meaning of life?") - ``` - """ - prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs) - return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs) + # optionally load adapters. In case of multiple adapters, or on Runner, + # we will need to set load_adapters='all' + if load_adapters is not None: + adapters_to_load = adapter_mapping.keys() if load_adapters == "all" else load_adapters + for adapter_name in adapters_to_load: + _peft_config, _peft_model_id = adapter_mapping[adapter_name] + self.__llm_model__.load_adapter(_peft_model_id, adapter_name=adapter_name, is_trainable=not inference_mode, **dict(_peft_config.to_dict())) + return self.__llm_model__ + + def _wrap_default_peft_model(self, adapter_mapping: dict[str, tuple[peft.PeftConfig, str]], inference_mode: bool) -> M: + if self.__llm_model__ is None: raise ValueError("Error: Model is not loaded correctly") + if isinstance(self.__llm_model__, peft.PeftModel): return self.__llm_model__ + if not isinstance(self.__llm_model__, transformers.PreTrainedModel): raise ValueError("Loading LoRA layers currently only runs on PyTorch models.") + + if "default" not in adapter_mapping: raise ValueError("There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team.") + default_config, peft_model_id = adapter_mapping.pop("default") + + # the below shared similar logics with `get_peft_model` + # TODO: Support PromptLearningConfig + if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(default_config, peft.PromptLearningConfig): + logger.debug("Given task type '%s' is not supported by peft. Make sure the adapter is loaded manually before running inference.", default_config.task_type) + model = peft.PeftModel(self.__llm_model__, default_config) + else: + # XXX: this is not ideal to serialize like this, maybe for fine-tune we will only support 0.4.0 + # onwards. For now, keep this logic here. + peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type] + if default_config.base_model_name_or_path: + kwargs: DictStrAny = {"is_trainable": not inference_mode} + if "config" in inspect.signature(peft_class.from_pretrained).parameters: kwargs["config"] = default_config + else: kwargs.update(dict(default_config.to_dict().items())) + # BUG: This hits during inference, need fixing + model = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs) + else: model = peft_class(self.__llm_model__, default_config) # in this case, the given base_model_name_or_path is None. This will be hit during training + return model + + # order of these fields matter here, make sure to sync it with + # openllm.models.auto.factory.BaseAutoLLMClass.for_model + def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, scheduling_strategy: type[bentoml.Strategy] | None = None) -> LLMRunner[M, T]: + """Convert this LLM into a Runner. + + Args: + models: Any additional ``bentoml.Model`` to be included in this given models. + By default, this will be determined from the model_name. + max_batch_size: The maximum batch size for the runner. + max_latency_ms: The maximum latency for the runner. + strategy: The strategy to use for this runner. + embedded: Whether to run this runner in embedded mode. + scheduling_strategy: Whether to create a custom scheduling strategy for this Runner. + + Returns: + A generated LLMRunner for this LLM. + + > **Note**: There are some difference between bentoml.models.get().to_runner() and LLM.to_runner(): 'name'. + - 'name': will be generated by OpenLLM, hence users don't shouldn't worry about this. The generated name will be 'llm--runner' (ex: llm-dolly-v2-runner, llm-chatglm-runner) + - 'embedded': Will be disabled by default. There is no reason to run LLM in embedded mode. + - 'method_configs': The method configs for the runner will be managed internally by OpenLLM. + """ + models = models if models is not None else [] + + try: models.append(self._bentomodel) + except bentoml.exceptions.NotFound: models.append(serialisation.get(self, auto_import=True)) + + if scheduling_strategy is None: + from ._strategies import CascadingResourceStrategy + scheduling_strategy = CascadingResourceStrategy + + generate_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False))) + embeddings_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False))) + generate_iterator_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=True))) + + # NOTE: returning the two langchain API's to the runner + return llm_runner_class(self)( + llm_runnable_class(self, embeddings_sig, generate_sig, generate_iterator_sig), name=self.runner_name, embedded=False, models=models, max_batch_size=max_batch_size, max_latency_ms=max_latency_ms, + method_configs=bentoml_cattr.unstructure({"embeddings": embeddings_sig, "__call__": generate_sig, "generate": generate_sig, "generate_one": generate_sig, "generate_iterator": generate_iterator_sig}), scheduling_strategy=scheduling_strategy, + ) + + # NOTE: Scikit API + def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return self.__call__(prompt, **attrs) + def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: + """Returns the generation result and format the result. + + First, it runs `self.sanitize_parameters` to sanitize the parameters. + The the sanitized prompt and kwargs will be pass into self.generate. + Finally, run self.postprocess_generate to postprocess the generated result. + + This allows users to do the following: + + ```python + llm = openllm.AutoLLM.for_model("dolly-v2") + llm("What is the meaning of life?") + ``` + """ + prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs) + return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs) + +# fmt: off @overload def Runner(model_name: str, *, model_id: str | None = None, model_version: str | None = ..., init_local: t.Literal[False, True] = ..., **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: ... @overload -def Runner(model_name: str, *, model_id: str = ..., model_version: str | None = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, ModelSignatureDict | ModelSignature] | None = ..., embedded: t.Literal[True, False] = ..., scheduling_strategy: type[bentoml.Strategy] | None = ..., **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: ... +def Runner( + model_name: str, *, model_id: str = ..., model_version: str | None = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, ModelSignatureDict | ModelSignature] | None = ..., embedded: t.Literal[True, False] = ..., scheduling_strategy: type[bentoml.Strategy] | None = ..., **attrs: t.Any +) -> LLMRunner[t.Any, t.Any]: ... @overload def Runner(model_name: str, *, ensure_available: bool | None = None, init_local: bool = ..., implementation: LiteralRuntime | None = None, llm_config: LLMConfig | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: ... @overload -def Runner(model_name: str, *, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4", "gptq"] | None = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None = None, serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: ... +def Runner( + model_name: str, *, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4", "gptq"] | None = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., + adapter_map: dict[str, str | None] | None = ..., quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None = None, serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any +) -> LLMRunner[t.Any, t.Any]: ... +# fmt: on + def Runner(model_name: str, ensure_available: bool | None = None, init_local: bool = False, implementation: LiteralRuntime | None = None, llm_config: LLMConfig | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: - """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'. + """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'. - The behaviour of ensure_available that is synonymous to `AutoLLM.for_model` depends on `init_local`. - By default, `ensure_available` is synonymous to `init_local`, meaning on the service when creating - runner, it won't download the model. So before running your BentoML Service, you should create a `on_startup` - hook to check download if you don't want to do it manually: + The behaviour of ensure_available that is synonymous to `AutoLLM.for_model` depends on `init_local`. + By default, `ensure_available` is synonymous to `init_local`, meaning on the service when creating + runner, it won't download the model. So before running your BentoML Service, you should create a `on_startup` + hook to check download if you don't want to do it manually: - ```python + ```python - runner = openllm.Runner("dolly-v2") + runner = openllm.Runner("dolly-v2") - @svc.on_startup - def download(): - runner.download_model() - ``` + @svc.on_startup + def download(): + runner.download_model() + ``` - if `init_local=True` (For development workflow), it will also enable `ensure_available`. - Default value of `ensure_available` is None. If set then use that given value, otherwise fallback to the aforementioned behaviour. + if `init_local=True` (For development workflow), it will also enable `ensure_available`. + Default value of `ensure_available` is None. If set then use that given value, otherwise fallback to the aforementioned behaviour. - Args: - model_name: Supported model name from 'openllm models' - ensure_available: If True, it will download the model if it is not available. If False, it will skip downloading the model. - If False, make sure the model is available locally. - implementation: The given Runner implementation one choose for this Runner. By default, it is retrieved from the enviroment variable - of the respected model_name. For example: 'flan-t5' -> "OPENLLM_FLAN_T5_FRAMEWORK" - llm_config: Optional ``openllm.LLMConfig`` to initialise this ``openllm.LLMRunner``. - init_local: If True, it will initialize the model locally. This is useful if you want to - run the model locally. (Symmetrical to bentoml.Runner.init_local()) - **attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs - behaviour - """ - if llm_config is not None: - attrs.update( - { - "model_id": llm_config["env"]["model_id_value"], - "bettertransformer": llm_config["env"]["bettertransformer_value"], - "quantize": llm_config["env"]["quantize_value"], - "runtime": llm_config["env"]["runtime_value"], - "serialisation": first_not_none(os.getenv("OPENLLM_SERIALIZATION"), attrs.get("serialisation"), default="safetensors"), - } - ) + Args: + model_name: Supported model name from 'openllm models' + ensure_available: If True, it will download the model if it is not available. If False, it will skip downloading the model. + If False, make sure the model is available locally. + implementation: The given Runner implementation one choose for this Runner. By default, it is retrieved from the enviroment variable + of the respected model_name. For example: 'flan-t5' -> "OPENLLM_FLAN_T5_FRAMEWORK" + llm_config: Optional ``openllm.LLMConfig`` to initialise this ``openllm.LLMRunner``. + init_local: If True, it will initialize the model locally. This is useful if you want to + run the model locally. (Symmetrical to bentoml.Runner.init_local()) + **attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs + behaviour + """ + if llm_config is not None: + attrs.update({"model_id": llm_config["env"]["model_id_value"], "bettertransformer": llm_config["env"]["bettertransformer_value"], "quantize": llm_config["env"]["quantize_value"], "runtime": llm_config["env"]["runtime_value"], "serialisation": first_not_none(os.getenv("OPENLLM_SERIALIZATION"), attrs.get("serialisation"), default="safetensors"),}) - default_implementation = llm_config.default_implementation() if llm_config is not None else "pt" - implementation = first_not_none(implementation, default=EnvVarMixin(model_name, default_implementation)["framework_value"]) - runner = infer_auto_class(implementation).create_runner(model_name, llm_config=llm_config, ensure_available=ensure_available if ensure_available is not None else init_local, **attrs) - if init_local: runner.init_local(quiet=True) - return runner - - -def method_signature(sig: ModelSignature) -> ModelSignatureDict: return bentoml_cattr.unstructure(sig) + default_implementation = llm_config.default_implementation() if llm_config is not None else "pt" + implementation = first_not_none(implementation, default=EnvVarMixin(model_name, default_implementation)["framework_value"]) + runner = infer_auto_class(implementation).create_runner(model_name, llm_config=llm_config, ensure_available=ensure_available if ensure_available is not None else init_local, **attrs) + if init_local: runner.init_local(quiet=True) + return runner +def method_signature(sig: ModelSignature) -> ModelSignatureDict: + return bentoml_cattr.unstructure(sig) class SetAdapterOutput(t.TypedDict): - success: bool - message: str - + success: bool + message: str def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate_sig: ModelSignature, generate_iterator_sig: ModelSignature) -> type[LLMRunnable[M, T]]: - class _Runnable(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True + class _Runnable(bentoml.Runnable): + SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu", "cpu") + SUPPORTS_CPU_MULTI_THREADING = True - def __init__(__self: _Runnable): - # NOTE: The side effect of this line - # is that it will load the imported model during - # runner startup. So don't remove it!! - if not self.model: raise RuntimeError("Failed to load the model correctly (See traceback above)") - if self.adapters_mapping is not None: - logger.info("Applying LoRA to %s...", self.runner_name) - self.apply_adapter(inference_mode=True, load_adapters="all") - @requires_dependencies("peft", extra="fine-tune") - def set_adapter(__self: _Runnable, adapter_name: str) -> None: - if self.__llm_adapter_map__ is None: raise ValueError("No adapters available for current running server.") - elif not isinstance(self.model, peft.PeftModel): raise RuntimeError("Model is not a PeftModel") - if adapter_name != "default": self.model.set_adapter(adapter_name) - logger.info("Successfully apply LoRA layer %s", adapter_name) - @bentoml.Runnable.method(**method_signature(embeddings_sig)) - def embeddings(__self: _Runnable, prompt: str | list[str]) -> LLMEmbeddings: return self.embeddings([prompt] if isinstance(prompt, str) else prompt) - @bentoml.Runnable.method(**method_signature(generate_sig)) - def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: - adapter_name = attrs.pop("adapter_name", None) - if adapter_name is not None: __self.set_adapter(adapter_name) - return self.generate(prompt, **attrs) - @bentoml.Runnable.method(**method_signature(generate_sig)) - def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: - adapter_name = attrs.pop("adapter_name", None) - if adapter_name is not None: __self.set_adapter(adapter_name) - return self.generate(prompt, **attrs) - @bentoml.Runnable.method(**method_signature(generate_sig)) - def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal["generated_text"], str]]: - adapter_name = attrs.pop("adapter_name", None) - if adapter_name is not None: __self.set_adapter(adapter_name) - return self.generate_one(prompt, stop, **attrs) - @bentoml.Runnable.method(**method_signature(generate_iterator_sig)) - def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[t.Any, None, None]: - adapter_name = attrs.pop("adapter_name", None) - if adapter_name is not None: __self.set_adapter(adapter_name) - yield self.generate_iterator(prompt, **attrs) + def __init__(__self: _Runnable): + # NOTE: The side effect of this line + # is that it will load the imported model during + # runner startup. So don't remove it!! + if not self.model: raise RuntimeError("Failed to load the model correctly (See traceback above)") + if self.adapters_mapping is not None: + logger.info("Applying LoRA to %s...", self.runner_name) + self.apply_adapter(inference_mode=True, load_adapters="all") - return types.new_class( - self.__class__.__name__ + "Runnable", - (_Runnable,), - {}, - lambda ns: ns.update( - { - "SUPPORTED_RESOURCES": ("nvidia.com/gpu", "amd.com/gpu") - if self.config["requires_gpu"] - else ("nvidia.com/gpu", "amd.com/gpu", "cpu"), - "__module__": self.__module__, - "__doc__": self.config["env"].start_docstring, - } - ), - ) + @requires_dependencies("peft", extra="fine-tune") + def set_adapter(__self: _Runnable, adapter_name: str) -> None: + if self.__llm_adapter_map__ is None: raise ValueError("No adapters available for current running server.") + elif not isinstance(self.model, peft.PeftModel): raise RuntimeError("Model is not a PeftModel") + if adapter_name != "default": self.model.set_adapter(adapter_name) + logger.info("Successfully apply LoRA layer %s", adapter_name) + + @bentoml.Runnable.method(**method_signature(embeddings_sig)) + def embeddings(__self: _Runnable, prompt: str | list[str]) -> LLMEmbeddings: + return self.embeddings([prompt] if isinstance(prompt, str) else prompt) + + @bentoml.Runnable.method(**method_signature(generate_sig)) + def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: + adapter_name = attrs.pop("adapter_name", None) + if adapter_name is not None: __self.set_adapter(adapter_name) + return self.generate(prompt, **attrs) + + @bentoml.Runnable.method(**method_signature(generate_sig)) + def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: + adapter_name = attrs.pop("adapter_name", None) + if adapter_name is not None: __self.set_adapter(adapter_name) + return self.generate(prompt, **attrs) + + @bentoml.Runnable.method(**method_signature(generate_sig)) + def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal["generated_text"], str]]: + adapter_name = attrs.pop("adapter_name", None) + if adapter_name is not None: __self.set_adapter(adapter_name) + return self.generate_one(prompt, stop, **attrs) + + @bentoml.Runnable.method(**method_signature(generate_iterator_sig)) + def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[t.Any, None, None]: + adapter_name = attrs.pop("adapter_name", None) + if adapter_name is not None: __self.set_adapter(adapter_name) + yield self.generate_iterator(prompt, **attrs) + + return types.new_class(self.__class__.__name__ + "Runnable", (_Runnable,), {}, lambda ns: ns.update({"SUPPORTED_RESOURCES": ("nvidia.com/gpu", "amd.com/gpu") if self.config["requires_gpu"] else ("nvidia.com/gpu", "amd.com/gpu", "cpu"), "__module__": self.__module__, "__doc__": self.config["env"].start_docstring})) def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]: - def available_adapters(_: LLMRunner[M, T]) -> PeftAdapterOutput: - if not is_peft_available(): return {"success": False, "result": {}, "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'"} - if self.__llm_adapter_map__ is None: return {"success": False, "result": {}, "error_msg": "No adapters available for current running server."} - if not isinstance(self.model, peft.PeftModel): return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"} - return {"success": True, "result": self.model.peft_config, "error_msg": ""} - def _wrapped_generate_run(__self: LLMRunner[M, T], prompt: str, **kwargs: t.Any) -> t.Any: - """Wrapper for runner.generate.run() to handle the prompt and postprocessing. + def available_adapters(_: LLMRunner[M, T]) -> PeftAdapterOutput: + if not is_peft_available(): return {"success": False, "result": {}, "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'"} + if self.__llm_adapter_map__ is None: return {"success": False, "result": {}, "error_msg": "No adapters available for current running server."} + if not isinstance(self.model, peft.PeftModel): return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"} + return {"success": True, "result": self.model.peft_config, "error_msg": ""} + def _wrapped_generate_run(__self: LLMRunner[M, T], prompt: str, **kwargs: t.Any) -> t.Any: + """Wrapper for runner.generate.run() to handle the prompt and postprocessing. - This will be used for LangChain API. + This will be used for LangChain API. - Usage: - ```python - runner = openllm.Runner("dolly-v2", init_local=True) - runner("What is the meaning of life?") - ``` - """ - prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs) - return self.postprocess_generate(prompt, __self.generate.run(prompt, **generate_kwargs), **postprocess_kwargs) - def _wrapped_embeddings_run(__self: LLMRunner[M, T], prompt: str | list[str]) -> LLMEmbeddings: - """``llm.embed`` is a light wrapper around runner.embeedings.run(). + Usage: - Usage: - ```python - runner = openllm.Runner('llama', implementation='pt') - runner.embed("What is the meaning of life?") - ``` - """ - return __self.embeddings.run([prompt] if isinstance(prompt, str) else prompt) - def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {"config", "llm_type", "runner_methods", "runtime", "llm_tag"} - def _wrapped_repr_args(__self: LLMRunner[M, T]) -> ReprArgs: - yield "runner_methods", {method.name: {"batchable": method.config.batchable, "batch_dim": method.config.batch_dim if method.config.batchable else None} for method in __self.runner_methods} - yield "config", self.config.model_dump(flatten=True) - yield "llm_type", __self.llm_type - yield "runtime", self.runtime - yield "llm_tag", self.tag - return types.new_class( - self.__class__.__name__ + "Runner", - (bentoml.Runner,), - exec_body=lambda ns: ns.update( - { - "llm_type": self.llm_type, - "identifying_params": self.identifying_params, - "llm_tag": self.tag, - "llm": self, # NOTE: self reference to LLM - "config": self.config, - "implementation": self.__llm_implementation__, - "peft_adapters": property(fget=available_adapters), - "download_model": self.ensure_model_id_exists, - "__call__": _wrapped_generate_run, - "embed": _wrapped_embeddings_run, - "__module__": self.__module__, - "__doc__": self.config["env"].start_docstring, - "__repr__": ReprMixin.__repr__, - "__repr_keys__": property(_wrapped_repr_keys), - "__repr_args__": _wrapped_repr_args, - "supports_embeddings": self["supports_embeddings"], - "supports_hf_agent": self["supports_generate_one"], - "has_adapters": self._adapters_mapping is not None, - } - ), - ) + ```python + runner = openllm.Runner("dolly-v2", init_local=True) + runner("What is the meaning of life?") + ``` + """ + prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs) + return self.postprocess_generate(prompt, __self.generate.run(prompt, **generate_kwargs), **postprocess_kwargs) + def _wrapped_embeddings_run(__self: LLMRunner[M, T], prompt: str | list[str]) -> LLMEmbeddings: + """``llm.embed`` is a light wrapper around runner.embeedings.run(). + + Usage: + + ```python + runner = openllm.Runner('llama', implementation='pt') + runner.embed("What is the meaning of life?") + ``` + """ + return __self.embeddings.run([prompt] if isinstance(prompt, str) else prompt) + def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {"config", "llm_type", "runner_methods", "runtime", "llm_tag"} + def _wrapped_repr_args(__self: LLMRunner[M, T]) -> ReprArgs: + yield "runner_methods", {method.name: {"batchable": method.config.batchable, "batch_dim": method.config.batch_dim if method.config.batchable else None} for method in __self.runner_methods} + yield "config", self.config.model_dump(flatten=True) + yield "llm_type", __self.llm_type + yield "runtime", self.runtime + yield "llm_tag", self.tag + + return types.new_class( + self.__class__.__name__ + "Runner", (bentoml.Runner,), + exec_body=lambda ns: ns.update({ + "llm_type": self.llm_type, + "identifying_params": self.identifying_params, + "llm_tag": self.tag, + "llm": self, # NOTE: self reference to LLM + "config": self.config, + "implementation": self.__llm_implementation__, + "peft_adapters": property(fget=available_adapters), + "download_model": self.ensure_model_id_exists, + "__call__": _wrapped_generate_run, + "embed": _wrapped_embeddings_run, + "__module__": self.__module__, + "__doc__": self.config["env"].start_docstring, + "__repr__": ReprMixin.__repr__, + "__repr_keys__": property(_wrapped_repr_keys), + "__repr_args__": _wrapped_repr_args, + "supports_embeddings": self["supports_embeddings"], + "supports_hf_agent": self["supports_generate_one"], + "has_adapters": self._adapters_mapping is not None, + }), + ) diff --git a/src/openllm/_prompt.py b/src/openllm/_prompt.py index 353a42b5..d8c54291 100644 --- a/src/openllm/_prompt.py +++ b/src/openllm/_prompt.py @@ -16,33 +16,25 @@ import string import typing as t class PromptFormatter(string.Formatter): - """This PromptFormatter is largely based on langchain's implementation.""" - - def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: - if len(args) > 0: - raise ValueError("Positional arguments are not supported") - return super().vformat(format_string, args, kwargs) - - def check_unused_args( - self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] - ) -> None: - """Check if extra params is passed.""" - extras = set(kwargs).difference(used_args) - if extras: - raise KeyError(f"Extra params passed: {extras}") - - def extract_template_variables(self, template: str) -> t.Sequence[str]: - """Extract template variables from a template string.""" - return [field[1] for field in self.parse(template) if field[1] is not None] + """This PromptFormatter is largely based on langchain's implementation.""" + def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: + if len(args) > 0: raise ValueError("Positional arguments are not supported") + return super().vformat(format_string, args, kwargs) + def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: + extras = set(kwargs).difference(used_args) + if extras: raise KeyError(f"Extra params passed: {extras}") + def extract_template_variables(self, template: str) -> t.Sequence[str]: return [field[1] for field in self.parse(template) if field[1] is not None] default_formatter = PromptFormatter() def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str: - # Currently, all default prompt will always have `instruction` key. - if not use_prompt_template: return prompt - elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'") - template_variables = default_formatter.extract_template_variables(template) - prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} - if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'") - try: return template.format(instruction=prompt, **prompt_variables) - except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None + # Currently, all default prompt will always have `instruction` key. + if not use_prompt_template: return prompt + elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'") + template_variables = default_formatter.extract_template_variables(template) + prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} + if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'") + try: + return template.format(instruction=prompt, **prompt_variables) + except KeyError as e: + raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None diff --git a/src/openllm/_quantisation.py b/src/openllm/_quantisation.py index c6e24585..2141c607 100644 --- a/src/openllm/_quantisation.py +++ b/src/openllm/_quantisation.py @@ -25,119 +25,78 @@ from .utils import pkg # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import overload + from typing import overload else: - from typing_extensions import overload - + from typing_extensions import overload if t.TYPE_CHECKING: - import auto_gptq as autogptq - import torch + import auto_gptq as autogptq + import torch - import openllm - import transformers + import openllm + import transformers - from ._types import DictStrAny + from ._types import DictStrAny else: - autogptq = LazyLoader("autogptq", globals(), "auto_gptq") - torch = LazyLoader("torch", globals(), "torch") - transformers = LazyLoader("transformers", globals(), "transformers") + autogptq = LazyLoader("autogptq", globals(), "auto_gptq") + torch = LazyLoader("torch", globals(), "torch") + transformers = LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) QuantiseMode = t.Literal["int8", "int4", "gptq"] - +# fmt: off @overload -def infer_quantisation_config( - cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any -) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: - ... - - +def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ... @overload -def infer_quantisation_config( - cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any -) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]: - ... +def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]: ... +# fmt: on +def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]: + # 8 bit configuration + int8_threshold = attrs.pop("llm_int8_threshhold", 6.0) + int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False) + int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None) + int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False) + autogptq_attrs: DictStrAny = {"bits": attrs.pop("gptq_bits", 4), "group_size": attrs.pop("gptq_group_size", -1), "damp_percent": attrs.pop("gptq_damp_percent", 0.01), "desc_act": attrs.pop("gptq_desc_act", True), "sym": attrs.pop("gptq_sym", True), "true_sequential": attrs.pop("gptq_true_sequential", True),} -def infer_quantisation_config( - cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any -) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]: - # 8 bit configuration - int8_threshold = attrs.pop("llm_int8_threshhold", 6.0) - int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False) - int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None) - int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False) + def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig: + if int8_skip_modules is None: int8_skip_modules = [] + if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm": + logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__) + int8_skip_modules.append("lm_head") + return transformers.BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload, llm_int8_threshhold=int8_threshold, llm_int8_skip_modules=int8_skip_modules, llm_int8_has_fp16_weight=int8_has_fp16_weight,) - autogptq_attrs: DictStrAny = { - "bits": attrs.pop("gptq_bits", 4), - "group_size": attrs.pop("gptq_group_size", -1), - "damp_percent": attrs.pop("gptq_damp_percent", 0.01), - "desc_act": attrs.pop("gptq_desc_act", True), - "sym": attrs.pop("gptq_sym", True), - "true_sequential": attrs.pop("gptq_true_sequential", True), - } + # 4 bit configuration + int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16) + int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4") + int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True) - def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig: - if int8_skip_modules is None: - int8_skip_modules = [] - if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm": - logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__) - int8_skip_modules.append("lm_head") - return transformers.BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload, - llm_int8_threshhold=int8_threshold, - llm_int8_skip_modules=int8_skip_modules, - llm_int8_has_fp16_weight=int8_has_fp16_weight, - ) - - # 4 bit configuration - int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16) - int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4") - int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True) - - # NOTE: Quantization setup - # quantize is a openllm.LLM feature, where we can quantize the model - # with bitsandbytes or quantization aware training. - if not is_bitsandbytes_available(): - raise RuntimeError( - "Quantization requires bitsandbytes to be installed. Make " - "sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" - ) - if quantise == "int8": - quantisation_config = create_int8_config(int8_skip_modules) - elif quantise == "int4": - if is_transformers_supports_kbit(): - quantisation_config = transformers.BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=int4_compute_dtype, - bnb_4bit_quant_type=int4_quant_type, - bnb_4bit_use_double_quant=int4_use_double_quant, - ) - else: - logger.warning( - "'quantize' is set to int4, while the current transformers version %s does not support " - "k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore " - "make sure to install the latest version of transformers either via PyPI or " - "from git source: 'pip install git+https://github.com/huggingface/transformers'.", - pkg.pkg_version_info("transformers"), - ) - logger.warning("OpenLLM will fallback to 8-bit quantization.") - quantisation_config = create_int8_config(int8_skip_modules) - elif quantise == "gptq": - if not is_autogptq_available(): - logger.warning( - "'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment)." - " Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback " - "to int8 with bitsandbytes." - ) - quantisation_config = create_int8_config(int8_skip_modules) - else: - quantisation_config = autogptq.BaseQuantizeConfig(**autogptq_attrs) + # NOTE: Quantization setup + # quantize is a openllm.LLM feature, where we can quantize the model + # with bitsandbytes or quantization aware training. + if not is_bitsandbytes_available(): raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'") + if quantise == "int8": quantisation_config = create_int8_config(int8_skip_modules) + elif quantise == "int4": + if is_transformers_supports_kbit(): quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=int4_compute_dtype, bnb_4bit_quant_type=int4_quant_type, bnb_4bit_use_double_quant=int4_use_double_quant,) else: - raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.") - - return quantisation_config, attrs + logger.warning( + "'quantize' is set to int4, while the current transformers version %s does not support " + "k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore " + "make sure to install the latest version of transformers either via PyPI or " + "from git source: 'pip install git+https://github.com/huggingface/transformers'.", pkg.pkg_version_info("transformers"), + ) + logger.warning("OpenLLM will fallback to 8-bit quantization.") + quantisation_config = create_int8_config(int8_skip_modules) + elif quantise == "gptq": + if not is_autogptq_available(): + logger.warning("'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment)." + " Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback " + "to int8 with bitsandbytes.") + quantisation_config = create_int8_config(int8_skip_modules) + else: + quantisation_config = autogptq.BaseQuantizeConfig(**autogptq_attrs) + else: + raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.") + return quantisation_config, attrs diff --git a/src/openllm/_schema.py b/src/openllm/_schema.py index 7c9e928e..0855dfde 100644 --- a/src/openllm/_schema.py +++ b/src/openllm/_schema.py @@ -27,109 +27,71 @@ from .utils import bentoml_cattr from .utils import requires_dependencies if t.TYPE_CHECKING: - import vllm + import vllm - from ._types import DictStrAny + from ._types import DictStrAny else: - DictStrAny = dict - vllm = LazyLoader("vllm", globals(), "vllm") - + DictStrAny = dict + vllm = LazyLoader("vllm", globals(), "vllm") @attr.frozen(slots=True) class GenerationInput: - prompt: str - """The prompt to be sent to system.""" + prompt: str + """The prompt to be sent to system.""" + llm_config: LLMConfig + """A mapping of given LLM configuration values for given system.""" + @staticmethod + def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig: + if isinstance(data, LLMConfig): return data + elif LazyType(DictStrAny).isinstance(data): + if cls is None: raise ValueError("'cls' must pass if given data is a dictionary.") + return cls(**data) + else: + raise RuntimeError(f"Type {type(data)} is not yet supported.") - llm_config: LLMConfig - """A mapping of given LLM configuration values for given system.""" - - @staticmethod - def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig: - if isinstance(data, LLMConfig): - return data - elif LazyType(DictStrAny).isinstance(data): - if cls is None: - raise ValueError("'cls' must pass if given data is a dictionary.") - return cls(**data) - else: - raise RuntimeError(f"Type {type(data)} is not yet supported.") - - @classmethod - def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]: - from .models.auto import AutoConfig - - llm_config = AutoConfig.for_model(model_name, **attrs) - return attr.make_class( - inflection.camelize(llm_config["model_name"]) + "GenerationInput", - attrs={ - "prompt": attr.field(type=str), - "llm_config": attr.field( - type=llm_config.__class__, - default=llm_config, - converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__), - ), - }, - ) - - def model_dump(self) -> dict[str, t.Any]: - return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True)} + @classmethod + def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]: + from .models.auto import AutoConfig + llm_config = AutoConfig.for_model(model_name, **attrs) + return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__))}) + def model_dump(self) -> dict[str, t.Any]: + return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True)} @attr.frozen(slots=True) class GenerationOutput: - responses: t.List[t.Any] - """A list of responses from the system.""" + responses: t.List[t.Any] + """A list of responses from the system.""" + configuration: t.Dict[str, t.Any] + """A mapping of configuration values for given system.""" + @property + def marshaled_config(self) -> GenerationConfig: + return bentoml_cattr.structure(self.configuration, GenerationConfig) - configuration: t.Dict[str, t.Any] - """A mapping of configuration values for given system.""" - - @property - def marshaled_config(self) -> GenerationConfig: - return bentoml_cattr.structure(self.configuration, GenerationConfig) - - @property - def unmarshaled(self) -> dict[str, t.Any]: - return bentoml_cattr.unstructure(self) - - def __getitem__(self, key: str) -> t.Any: - if hasattr(self, key): return getattr(self, key) - elif key in self.configuration: return self.configuration[key] - else: raise KeyError(key) + @property + def unmarshaled(self) -> dict[str, t.Any]: + return bentoml_cattr.unstructure(self) + def __getitem__(self, key: str) -> t.Any: + if hasattr(self, key): return getattr(self, key) + elif key in self.configuration: return self.configuration[key] + else: raise KeyError(key) @attr.frozen(slots=True) class MetadataOutput: - model_id: str - timeout: int - model_name: str - framework: str - configuration: str - supports_embeddings: bool - supports_hf_agent: bool - + model_id: str + timeout: int + model_name: str + framework: str + configuration: str + supports_embeddings: bool + supports_hf_agent: bool @attr.frozen(slots=True) class EmbeddingsOutput: - embeddings: t.List[t.List[float]] - num_tokens: int - + embeddings: t.List[t.List[float]] + num_tokens: int @requires_dependencies("vllm", extra="vllm") def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> DictStrAny: - return dict( - request_id=request_output.request_id, - prompt=request_output.prompt, - finished=request_output.finished, - prompt_token_ids=request_output.prompt_token_ids, - outputs=[ - dict( - index=it.index, - text=it.text, - token_ids=it.token_ids, - cumulative_logprob=it.cumulative_logprob, - logprobs=it.logprobs, - finish_reason=it.finish_reason, - ) - for it in request_output.outputs - ], - ) + return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs],) diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 9937bb65..9d79ba26 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -25,8 +25,8 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route if t.TYPE_CHECKING: - from starlette.requests import Request - from starlette.responses import Response + from starlette.requests import Request + from starlette.responses import Response # The following warnings from bitsandbytes, and probably not that important for users to see warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization") warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization") @@ -36,37 +36,59 @@ adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}"" llm_config = openllm.AutoConfig.for_model(model) runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map)) svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner]) + @svc.api(input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}), route="/v1/generate") async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: - qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict) - config = qa_inputs.llm_config.model_dump() - responses = await runner.generate.async_run(qa_inputs.prompt, **config) - return openllm.GenerationOutput(responses=responses, configuration=config) + qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict) + config = qa_inputs.llm_config.model_dump() + responses = await runner.generate.async_run(qa_inputs.prompt, **config) + return openllm.GenerationOutput(responses=responses, configuration=config) + @svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent}), route="/v1/metadata") -def metadata_v1(_: str) -> openllm.MetadataOutput: return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,) +def metadata_v1(_: str) -> openllm.MetadataOutput: + return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,) + if runner.supports_embeddings: - @svc.api(input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings") - async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput: - responses = await runner.embeddings.async_run(phrases) - return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"]) + + @svc.api( + input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({ + "embeddings": [ + 0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, + 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076 + ], "num_tokens": 20 + }), route="/v1/embeddings" + ) + async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput: + responses = await runner.embeddings.async_run(phrases) + return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"]) + if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent(): - @attr.define - class HfAgentInput: - inputs: str - parameters: t.Dict[str, t.Any] - async def hf_agent(request: Request) -> Response: - json_str = await request.body() - try: input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), HfAgentInput) - except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None - stop = input_data.parameters.pop("stop", ["\n"]) - try: return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200) - except NotImplementedError: return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500) - hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])]) - svc.mount_asgi_app(hf_app, path="/hf") + + @attr.define + class HfAgentInput: + inputs: str + parameters: t.Dict[str, t.Any] + + async def hf_agent(request: Request) -> Response: + json_str = await request.body() + try: + input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), HfAgentInput) + except orjson.JSONDecodeError as err: + raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None + stop = input_data.parameters.pop("stop", ["\n"]) + try: + return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200) + except NotImplementedError: + return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500) + + hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])]) + svc.mount_asgi_app(hf_app, path="/hf") + async def list_adapter_v1(_: Request) -> Response: - res: dict[str, t.Any] = {} - if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()} - res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]}) - return JSONResponse(res, status_code=200) + res: dict[str, t.Any] = {} + if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()} + res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]}) + return JSONResponse(res, status_code=200) + adapters_app_v1 = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])]) svc.mount_asgi_app(adapters_app_v1, path="/v1") diff --git a/src/openllm/_strategies.py b/src/openllm/_strategies.py index 69465163..d84946ae 100644 --- a/src/openllm/_strategies.py +++ b/src/openllm/_strategies.py @@ -35,223 +35,220 @@ from .utils import LazyType from .utils import ReprMixin if t.TYPE_CHECKING: - ListIntStr = list[int | str] - class DynResource(bentoml.Resource[t.List[str]], resource_id=""): - resource_id: t.ClassVar[str] + ListIntStr = list[int | str] + + class DynResource(bentoml.Resource[t.List[str]], resource_id=""): + resource_id: t.ClassVar[str] else: - DynResource = bentoml.Resource[t.List[str]] - ListIntStr = list + DynResource = bentoml.Resource[t.List[str]] + ListIntStr = list # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import overload + from typing import overload else: - from typing_extensions import overload + from typing_extensions import overload logger = logging.getLogger(__name__) def _strtoul(s: str) -> int: - """Return -1 or positive integer sequence string starts with,.""" - if not s: return -1 - idx = 0 - for idx, c in enumerate(s): - if not (c.isdigit() or (idx == 0 and c in "+-")): break - if idx + 1 == len(s): idx += 1 # noqa: PLW2901 - # NOTE: idx will be set via enumerate - return int(s[:idx]) if idx > 0 else -1 - + """Return -1 or positive integer sequence string starts with,.""" + if not s: return -1 + idx = 0 + for idx, c in enumerate(s): + if not (c.isdigit() or (idx == 0 and c in "+-")): break + if idx + 1 == len(s): idx += 1 # noqa: PLW2901 + # NOTE: idx will be set via enumerate + return int(s[:idx]) if idx > 0 else -1 def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]: - rcs: list[str] = [] - for elem in lst.split(","): - # Repeated id results in empty set - if elem in rcs: return [] - # Anything other but prefix is ignored - if not elem.startswith(prefix): break - rcs.append(elem) - return rcs - + rcs: list[str] = [] + for elem in lst.split(","): + # Repeated id results in empty set + if elem in rcs: return [] + # Anything other but prefix is ignored + if not elem.startswith(prefix): break + rcs.append(elem) + return rcs _STACK_LEVEL = 3 @overload -def _parse_visible_devices(default_var: str | None = ..., respect_env: t.Literal[True] = True) -> list[str] | None: ... +def _parse_visible_devices(default_var: str | None = ..., respect_env: t.Literal[True] = True) -> list[str] | None: + ... + @overload -def _parse_visible_devices(default_var: str = ..., respect_env: t.Literal[False] = ...) -> list[str]: ... +def _parse_visible_devices(default_var: str = ..., respect_env: t.Literal[False] = ...) -> list[str]: + ... + def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None: - """CUDA_VISIBLE_DEVICES aware with default var for parsing spec.""" - if respect_env: - spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var) - if not spec: return None - else: - if default_var is None: raise ValueError("spec is required to be not None when parsing spec.") - spec = default_var + """CUDA_VISIBLE_DEVICES aware with default var for parsing spec.""" + if respect_env: + spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var) + if not spec: return None + else: + if default_var is None: raise ValueError("spec is required to be not None when parsing spec.") + spec = default_var - if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-") - if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-") - - # XXX: We to somehow handle cases such as '100m' - # CUDA_VISIBLE_DEVICES uses something like strtoul - # which makes `1gpu2,2ampere` is equivalent to `1,2` - rc: list[int] = [] - for el in spec.split(","): - x = _strtoul(el.strip()) - # Repeated ordinal results in empty set - if x in rc: return [] - # Negative value aborts the sequence - if x < 0: break - rc.append(x) - return [str(i) for i in rc] + if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-") + if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-") + # XXX: We to somehow handle cases such as '100m' + # CUDA_VISIBLE_DEVICES uses something like strtoul + # which makes `1gpu2,2ampere` is equivalent to `1,2` + rc: list[int] = [] + for el in spec.split(","): + x = _strtoul(el.strip()) + # Repeated ordinal results in empty set + if x in rc: return [] + # Negative value aborts the sequence + if x < 0: break + rc.append(x) + return [str(i) for i in rc] def _from_system(cls: type[DynResource]) -> list[str]: - """Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation. + """Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation. - It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES. - """ - visible_devices = _parse_visible_devices() - if visible_devices is None: - if cls.resource_id == "amd.com/gpu": - if not psutil.LINUX: - if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL) - return [] + It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES. + """ + visible_devices = _parse_visible_devices() + if visible_devices is None: + if cls.resource_id == "amd.com/gpu": + if not psutil.LINUX: + if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL) + return [] - # ROCm does not currently have the rocm_smi wheel. - # So we need to use the ctypes bindings directly. - # we don't want to use CLI because parsing is a pain. - sys.path.append("/opt/rocm/libexec/rocm_smi") - try: - from ctypes import byref - from ctypes import c_uint32 + # ROCm does not currently have the rocm_smi wheel. + # So we need to use the ctypes bindings directly. + # we don't want to use CLI because parsing is a pain. + sys.path.append("/opt/rocm/libexec/rocm_smi") + try: + from ctypes import byref + from ctypes import c_uint32 - # refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py - from rsmiBindings import rocmsmi - from rsmiBindings import rsmi_status_t - - device_count = c_uint32(0) - ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count)) - if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)] - return [] - # In this case the binary is not found, returning empty list - except (ModuleNotFoundError, ImportError): return [] - finally: sys.path.remove("/opt/rocm/libexec/rocm_smi") - else: - try: - from cuda import cuda - cuda.cuInit(0) - _, dev = cuda.cuDeviceGetCount() - return [str(i) for i in range(dev)] - except (ImportError, RuntimeError, AttributeError): return [] - return visible_devices + # refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py + from rsmiBindings import rocmsmi + from rsmiBindings import rsmi_status_t + device_count = c_uint32(0) + ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count)) + if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)] + return [] + # In this case the binary is not found, returning empty list + except (ModuleNotFoundError, ImportError): + return [] + finally: + sys.path.remove("/opt/rocm/libexec/rocm_smi") + else: + try: + from cuda import cuda + cuda.cuInit(0) + _, dev = cuda.cuDeviceGetCount() + return [str(i) for i in range(dev)] + except (ImportError, RuntimeError, AttributeError): + return [] + return visible_devices @overload -def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ... +def _from_spec(cls: type[DynResource], spec: int) -> list[str]: + ... + @overload -def _from_spec(cls: type[DynResource], spec: ListIntStr) -> list[str]: ... +def _from_spec(cls: type[DynResource], spec: ListIntStr) -> list[str]: + ... + @overload -def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ... +def _from_spec(cls: type[DynResource], spec: str) -> list[str]: + ... + def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]: - """Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation. - - The parser behaves similar to how PyTorch handles CUDA_VISIBLE_DEVICES. This means within - BentoML's resource configuration, its behaviour is similar to CUDA_VISIBLE_DEVICES. - """ - if isinstance(spec, int): - if spec in (-1, 0): return [] - if spec < -1: raise ValueError("Spec cannot be < -1.") - return [str(i) for i in range(spec)] - elif isinstance(spec, str): - if not spec: return [] - if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))]) - return _parse_visible_devices(spec, respect_env=False) - elif LazyType(ListIntStr).isinstance(spec): return [str(x) for x in spec] - else: raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.") + """Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation. + The parser behaves similar to how PyTorch handles CUDA_VISIBLE_DEVICES. This means within + BentoML's resource configuration, its behaviour is similar to CUDA_VISIBLE_DEVICES. + """ + if isinstance(spec, int): + if spec in (-1, 0): return [] + if spec < -1: raise ValueError("Spec cannot be < -1.") + return [str(i) for i in range(spec)] + elif isinstance(spec, str): + if not spec: return [] + if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))]) + return _parse_visible_devices(spec, respect_env=False) + elif LazyType(ListIntStr).isinstance(spec): + return [str(x) for x in spec] + else: + raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.") def _raw_device_uuid_nvml() -> list[str] | None: - """Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed.""" - from ctypes import CDLL - from ctypes import byref - from ctypes import c_int - from ctypes import c_void_p - from ctypes import create_string_buffer + """Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed.""" + from ctypes import CDLL + from ctypes import byref + from ctypes import c_int + from ctypes import c_void_p + from ctypes import create_string_buffer - try: nvml_h = CDLL("libnvidia-ml.so.1") - except Exception: - warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL) - return None + try: + nvml_h = CDLL("libnvidia-ml.so.1") + except Exception: + warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL) + return None - rc = nvml_h.nvmlInit() + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL) + return None + dev_count = c_int(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + if rc != 0: + warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL) + return None + uuids: list[str] = [] + for idx in range(dev_count.value): + dev_id = c_void_p() + rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) if rc != 0: - warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL) - return None - dev_count = c_int(-1) - rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL) + return None + buf_len = 96 + buf = create_string_buffer(buf_len) + rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) if rc != 0: - warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL) - return None - uuids: list[str] = [] - for idx in range(dev_count.value): - dev_id = c_void_p() - rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) - if rc != 0: - warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL) - return None - buf_len = 96 - buf = create_string_buffer(buf_len) - rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) - if rc != 0: - warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL) - return None - uuids.append(buf.raw.decode("ascii").strip("\0")) - del nvml_h - return uuids - + warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL) + return None + uuids.append(buf.raw.decode("ascii").strip("\0")) + del nvml_h + return uuids def _validate(cls: type[DynResource], val: list[t.Any]) -> None: - if cls.resource_id == "amd.com/gpu": - raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'") - if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.") + if cls.resource_id == "amd.com/gpu": + raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'") + if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.") - try: - from cuda import cuda - - err, *_ = cuda.cuInit(0) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("Failed to initialise CUDA runtime binding.") - # correctly parse handle - for el in val: - if el.startswith("GPU-") or el.startswith("MIG-"): - uuids = _raw_device_uuid_nvml() - if uuids is None: raise ValueError("Failed to parse available GPUs UUID") - if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})") - elif el.isdigit(): - err, _ = cuda.cuDeviceGet(int(el)) - if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}") - except (ImportError, RuntimeError): - pass + try: + from cuda import cuda + err, *_ = cuda.cuInit(0) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Failed to initialise CUDA runtime binding.") + # correctly parse handle + for el in val: + if el.startswith("GPU-") or el.startswith("MIG-"): + uuids = _raw_device_uuid_nvml() + if uuids is None: raise ValueError("Failed to parse available GPUs UUID") + if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})") + elif el.isdigit(): + err, _ = cuda.cuDeviceGet(int(el)) + if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}") + except (ImportError, RuntimeError): + pass def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]: - return types.new_class( - name, - (DynResource, ReprMixin), - {"resource_id": resource_kind}, - lambda ns: ns.update( - { - "resource_id": resource_kind, - "from_spec": classmethod(_from_spec), - "from_system": classmethod(_from_system), - "validate": classmethod(_validate), - "__repr_keys__": property(lambda _: {"resource_id"}), - "__doc__": inspect.cleandoc(docstring), - "__module__": "openllm._strategies", - } - ), - ) + return types.new_class( + name, (DynResource, ReprMixin), {"resource_id": resource_kind}, lambda ns: ns.update({"resource_id": resource_kind, "from_spec": classmethod(_from_spec), "from_system": classmethod(_from_system), "validate": classmethod(_validate), "__repr_keys__": property(lambda _: {"resource_id"}), "__doc__": inspect.cleandoc(docstring), "__module__": "openllm._strategies",}), + ) # NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal :facepalm: _TPU_RESOURCE: t.Literal["cloud-tpus.google.com/v2"] = "cloud-tpus.google.com/v2" @@ -259,159 +256,146 @@ _AMD_GPU_RESOURCE: t.Literal["amd.com/gpu"] = "amd.com/gpu" _NVIDIA_GPU_RESOURCE: t.Literal["nvidia.com/gpu"] = "nvidia.com/gpu" _CPU_RESOURCE: t.Literal["cpu"] = "cpu" -NvidiaGpuResource = _make_resource_class( - "NvidiaGpuResource", - _NVIDIA_GPU_RESOURCE, - """NVIDIA GPU resource. +NvidiaGpuResource = _make_resource_class("NvidiaGpuResource", _NVIDIA_GPU_RESOURCE, """NVIDIA GPU resource. This is a modified version of internal's BentoML's NvidiaGpuResource - where it respects and parse CUDA_VISIBLE_DEVICES correctly.""", -) -AmdGpuResource = _make_resource_class( - "AmdGpuResource", - _AMD_GPU_RESOURCE, - """AMD GPU resource. + where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",) +AmdGpuResource = _make_resource_class("AmdGpuResource", _AMD_GPU_RESOURCE, """AMD GPU resource. Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to - ``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""", -) + ``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",) LiteralResourceSpec = t.Literal["cloud-tpus.google.com/v2", "amd.com/gpu", "nvidia.com/gpu", "cpu"] # convenient mapping -def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec: - if name == "tpu": return _TPU_RESOURCE - elif name == "amd": return _AMD_GPU_RESOURCE - elif name == "nvidia": return _NVIDIA_GPU_RESOURCE - elif name == "cpu": return _CPU_RESOURCE - else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']") +def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec: + if name == "tpu": return _TPU_RESOURCE + elif name == "amd": return _AMD_GPU_RESOURCE + elif name == "nvidia": return _NVIDIA_GPU_RESOURCE + elif name == "cpu": return _CPU_RESOURCE + else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']") + @functools.lru_cache def available_resource_spec() -> tuple[LiteralResourceSpec, ...]: - """This is a utility function helps to determine the available resources from given running system. + """This is a utility function helps to determine the available resources from given running system. - It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs. - - TODO: Supports TPUs - """ - available = () - if len(AmdGpuResource.from_system()) > 0: available += (_AMD_GPU_RESOURCE,) - if len(NvidiaGpuResource.from_system()) > 0: available += (_NVIDIA_GPU_RESOURCE,) - available += (_CPU_RESOURCE,) - return t.cast(t.Tuple[LiteralResourceSpec, ...], available) + It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs. + TODO: Supports TPUs + """ + available = () + if len(AmdGpuResource.from_system()) > 0: available += (_AMD_GPU_RESOURCE,) + if len(NvidiaGpuResource.from_system()) > 0: available += (_NVIDIA_GPU_RESOURCE,) + available += (_CPU_RESOURCE,) + return t.cast(t.Tuple[LiteralResourceSpec, ...], available) class CascadingResourceStrategy(bentoml.Strategy, ReprMixin): - """This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource. + """This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource. - It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU. - See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices - for ROCm's support for CUDA_VISIBLE_DEVICES. + It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU. + See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices + for ROCm's support for CUDA_VISIBLE_DEVICES. - TODO: Support CloudTPUResource + TODO: Support CloudTPUResource + """ + @classmethod + def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float) -> int: + if resource_request is None: resource_request = system_resources() + + def _get_gpu_count(typ: list[str] | None, kind: str) -> int | None: + if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: return math.ceil(len(typ) * workers_per_resource) + + # use NVIDIA + kind = "nvidia.com/gpu" + count = _get_gpu_count(get_resource(resource_request, kind), kind) + if count: return count + # use AMD + kind = "amd.com/gpu" + count = _get_gpu_count(get_resource(resource_request, kind, validate=False), kind) + if count: return count + # use CPU + cpus = get_resource(resource_request, "cpu") + if cpus is not None and cpus > 0: + if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class) + + if runnable_class.SUPPORTS_CPU_MULTI_THREADING: + if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.") + return int(workers_per_resource) + return math.ceil(cpus) * workers_per_resource + + # this should not be reached by user since we always read system resource as default + raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.") + + @classmethod + def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]: + """Get worker env for this given worker_index. + + Args: + runnable_class: The runnable class to be run. + resource_request: The resource request of the runnable. + workers_per_resource: # of workers per resource. + worker_index: The index of the worker, start from 0. """ - @classmethod - def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float) -> int: - if resource_request is None: resource_request = system_resources() - - def _get_gpu_count(typ: list[str] | None, kind: str) -> int | None: - if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: return math.ceil(len(typ) * workers_per_resource) - - # use NVIDIA - kind = "nvidia.com/gpu" - count = _get_gpu_count(get_resource(resource_request, kind), kind) - if count: return count - - # use AMD - kind = "amd.com/gpu" - count = _get_gpu_count(get_resource(resource_request, kind, validate=False), kind) - if count: return count - - # use CPU - cpus = get_resource(resource_request, "cpu") - if cpus is not None and cpus > 0: - if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class) - - if runnable_class.SUPPORTS_CPU_MULTI_THREADING: - if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.") - return int(workers_per_resource) - - return math.ceil(cpus) * workers_per_resource - - # this should not be reached by user since we always read system resource as default - raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.") - @classmethod - def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]: - """Get worker env for this given worker_index. - - Args: - runnable_class: The runnable class to be run. - resource_request: The resource request of the runnable. - workers_per_resource: # of workers per resource. - worker_index: The index of the worker, start from 0. - """ - cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None) - disabled = cuda_env in ("", "-1") - - environ: dict[str, t.Any] = {} - - if resource_request is None: resource_request = system_resources() - - # use NVIDIA - kind = "nvidia.com/gpu" - typ = get_resource(resource_request, kind) - if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: - if disabled: - logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index) - environ["CUDA_VISIBLE_DEVICES"] = cuda_env - return environ - environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index) - logger.debug("Environ for worker %s: %s", worker_index, environ) - return environ - - # use AMD - kind = "amd.com/gpu" - typ = get_resource(resource_request, kind, validate=False) - if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: - if disabled: - logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index) - environ["CUDA_VISIBLE_DEVICES"] = cuda_env - return environ - environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index) - logger.debug("Environ for worker %s: %s", worker_index, environ) - return environ - - # use CPU - cpus = get_resource(resource_request, "cpu") - if cpus is not None and cpus > 0: - environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu - if runnable_class.SUPPORTS_CPU_MULTI_THREADING: - thread_count = math.ceil(cpus) - for thread_env in THREAD_ENVS: environ[thread_env] = os.getenv(thread_env, str(thread_count)) - logger.debug("Environ for worker %s: %s", worker_index, environ) - return environ - for thread_env in THREAD_ENVS: environ[thread_env] = os.getenv(thread_env, "1") - return environ + cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None) + disabled = cuda_env in ("", "-1") + environ: dict[str, t.Any] = {} + if resource_request is None: resource_request = system_resources() + # use NVIDIA + kind = "nvidia.com/gpu" + typ = get_resource(resource_request, kind) + if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: + if disabled: + logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index) + environ["CUDA_VISIBLE_DEVICES"] = cuda_env return environ + environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index) + logger.debug("Environ for worker %s: %s", worker_index, environ) + return environ + # use AMD + kind = "amd.com/gpu" + typ = get_resource(resource_request, kind, validate=False) + if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: + if disabled: + logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index) + environ["CUDA_VISIBLE_DEVICES"] = cuda_env + return environ + environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index) + logger.debug("Environ for worker %s: %s", worker_index, environ) + return environ + # use CPU + cpus = get_resource(resource_request, "cpu") + if cpus is not None and cpus > 0: + environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu + if runnable_class.SUPPORTS_CPU_MULTI_THREADING: + thread_count = math.ceil(cpus) + for thread_env in THREAD_ENVS: + environ[thread_env] = os.getenv(thread_env, str(thread_count)) + logger.debug("Environ for worker %s: %s", worker_index, environ) + return environ + for thread_env in THREAD_ENVS: + environ[thread_env] = os.getenv(thread_env, "1") + return environ + return environ - @staticmethod - def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str: - # Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string. - if isinstance(workers_per_resource, float): - # NOTE: We hit this branch when workers_per_resource is set to - # float, for example 0.5 or 0.25 - if workers_per_resource > 1: - raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.") - # We are round the assigned resource here. This means if workers_per_resource=.4 - # then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2. - assigned_resource_per_worker = round(1 / workers_per_resource) - if len(gpus) < assigned_resource_per_worker: - logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker) - raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].") - assigned_gpu = gpus[assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1) ] - dev = ",".join(assigned_gpu) - else: - idx = worker_index // workers_per_resource - if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}") - dev = str(gpus[idx]) - return dev + @staticmethod + def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str: + # Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string. + if isinstance(workers_per_resource, float): + # NOTE: We hit this branch when workers_per_resource is set to + # float, for example 0.5 or 0.25 + if workers_per_resource > 1: + raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.") + # We are round the assigned resource here. This means if workers_per_resource=.4 + # then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2. + assigned_resource_per_worker = round(1 / workers_per_resource) + if len(gpus) < assigned_resource_per_worker: + logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker) + raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].") + assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index+1)] + dev = ",".join(assigned_gpu) + else: + idx = worker_index // workers_per_resource + if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}") + dev = str(gpus[idx]) + return dev diff --git a/src/openllm/_types.py b/src/openllm/_types.py index 60749bf0..49f4202b 100644 --- a/src/openllm/_types.py +++ b/src/openllm/_types.py @@ -19,10 +19,8 @@ It will raises a RuntimeError if this is imported eagerly. from __future__ import annotations import typing as t - if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime") - import attr import bentoml @@ -31,17 +29,15 @@ from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict from ._configuration import AdapterType from ._configuration import LiteralRuntime as LiteralRuntime - if t.TYPE_CHECKING: - import peft - - import openllm - from openllm._llm import M as _M - from openllm._llm import T as _T - from bentoml._internal.runner.runnable import RunnableMethod - from bentoml._internal.runner.runner import RunnerMethod - from bentoml._internal.runner.strategy import Strategy + import peft + import openllm + from openllm._llm import M as _M + from openllm._llm import T as _T + from bentoml._internal.runner.runnable import RunnableMethod + from bentoml._internal.runner.runner import RunnerMethod + from bentoml._internal.runner.strategy import Strategy AnyCallable = t.Callable[..., t.Any] DictStrAny = dict[str, t.Any] @@ -54,72 +50,72 @@ T = t.TypeVar("T") Ts = t.TypeVarTuple("Ts") At = t.TypeVar("At", bound=attr.AttrsInstance) - class PeftAdapterOutput(t.TypedDict): - success: bool - result: dict[str, peft.PeftConfig] - error_msg: str - + success: bool + result: dict[str, peft.PeftConfig] + error_msg: str class LLMEmbeddings(t.TypedDict): - embeddings: t.List[t.List[float]] - num_tokens: int - + embeddings: t.List[t.List[float]] + num_tokens: int class AdaptersTuple(TupleAny): - adapter_id: str - name: str | None - config: DictStrAny - + adapter_id: str + name: str | None + config: DictStrAny AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]] - class LLMRunnable(bentoml.Runnable, t.Generic[_M, _T]): - SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - __call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]] - set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]] - embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings] - generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]] - generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] - generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]] - + SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu") + SUPPORTS_CPU_MULTI_THREADING = True + __call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]] + set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]] + embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings] + generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]] + generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] + generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]] class LLMRunner(bentoml.Runner, t.Generic[_M, _T]): - __doc__: str - __module__: str - llm_type: str - identifying_params: dict[str, t.Any] - llm: openllm.LLM[_M, _T] - config: openllm.LLMConfig - implementation: LiteralRuntime - supports_embeddings: bool - supports_hf_agent: bool - has_adapters: bool - embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings] - generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]] - generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] - generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]] - def __init__( - self, - runnable_class: type[LLMRunnable[_M, _T]], - *, - runnable_init_params: dict[str, t.Any] | None = ..., - name: str | None = ..., - scheduling_strategy: type[Strategy] = ..., - models: list[bentoml.Model] | None = ..., - max_batch_size: int | None = ..., - max_latency_ms: int | None = ..., - method_configs: dict[str, dict[str, int]] | None = ..., - embedded: bool = False, - ) -> None: ... - def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ... - def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ... - def run(self, prompt: str, **attrs: t.Any) -> t.Any: ... - async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ... - def download_model(self) -> bentoml.Model: ... - @property - def peft_adapters(self) -> PeftAdapterOutput: ... - @property - def __repr_keys__(self) -> set[str]: ... + __doc__: str + __module__: str + llm_type: str + identifying_params: dict[str, t.Any] + llm: openllm.LLM[_M, _T] + config: openllm.LLMConfig + implementation: LiteralRuntime + supports_embeddings: bool + supports_hf_agent: bool + has_adapters: bool + embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings] + generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]] + generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] + generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]] + + def __init__( + self, runnable_class: type[LLMRunnable[_M, _T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False, + ) -> None: + ... + + def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: + ... + + def embed(self, prompt: str | list[str]) -> LLMEmbeddings: + ... + + def run(self, prompt: str, **attrs: t.Any) -> t.Any: + ... + + async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: + ... + + def download_model(self) -> bentoml.Model: + ... + + @property + def peft_adapters(self) -> PeftAdapterOutput: + ... + + @property + def __repr_keys__(self) -> set[str]: + ... diff --git a/src/openllm/bundle/__init__.py b/src/openllm/bundle/__init__.py index a7078136..2b5d7fc1 100644 --- a/src/openllm/bundle/__init__.py +++ b/src/openllm/bundle/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Build-related utilities. Some of these utilities are mainly used for 'openllm.build'. These utilities will stay internal, and its API can be changed or updated without backward-compatibility. @@ -23,20 +22,18 @@ import typing as t from . import oci as oci from ..utils import LazyModule -_import_structure: dict[str, list[str]] = { - "_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], - "oci": oci.__all__, -} +_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__,} if t.TYPE_CHECKING: - from . import _package as _package - from ._package import build_editable as build_editable - from ._package import construct_docker_options as construct_docker_options - from ._package import construct_python_options as construct_python_options - from ._package import create_bento as create_bento - from .oci import CONTAINER_NAMES as CONTAINER_NAMES - from .oci import build_container as build_container - from .oci import get_base_container_name as get_base_container_name - from .oci import get_base_container_tag as get_base_container_tag - from .oci import supported_registries as supported_registries -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from . import _package as _package + from ._package import build_editable as build_editable + from ._package import construct_docker_options as construct_docker_options + from ._package import construct_python_options as construct_python_options + from ._package import create_bento as create_bento + from .oci import CONTAINER_NAMES as CONTAINER_NAMES + from .oci import build_container as build_container + from .oci import get_base_container_name as get_base_container_name + from .oci import get_base_container_tag as get_base_container_tag + from .oci import supported_registries as supported_registries +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/bundle/_package.py b/src/openllm/bundle/_package.py index 94ef5cc6..3d6fde8e 100644 --- a/src/openllm/bundle/_package.py +++ b/src/openllm/bundle/_package.py @@ -45,200 +45,149 @@ from ..utils import is_torch_available from ..utils import pkg if t.TYPE_CHECKING: - from fs.base import FS + from fs.base import FS - import openllm - from bentoml._internal.bento import BentoStore - from bentoml._internal.models.model import ModelStore + import openllm + from bentoml._internal.bento import BentoStore + from bentoml._internal.models.model import ModelStore - from .oci import LiteralContainerRegistry - from .oci import LiteralContainerVersionStrategy + from .oci import LiteralContainerRegistry + from .oci import LiteralContainerVersionStrategy logger = logging.getLogger(__name__) OPENLLM_DEV_BUILD = "OPENLLM_DEV_BUILD" - def build_editable(path: str) -> str | None: - """Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.""" - if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None - # We need to build the package in editable mode, so that we can import it - from build import ProjectBuilder - from build.env import IsolatedEnvBuilder - module_location = pkg.source_locations("openllm") - if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.") - pyproject_path = Path(module_location).parent.parent / "pyproject.toml" - if os.path.isfile(pyproject_path.__fspath__()): - logger.info("OpenLLM is installed in editable mode. Generating built wheels...") - with IsolatedEnvBuilder() as env: - builder = ProjectBuilder(pyproject_path.parent) - builder.python_executable = env.executable - builder.scripts_dir = env.scripts_dir - env.install(builder.build_system_requires) - return builder.build("wheel", path, config_settings={"--global-option": "--quiet"}) - raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.") + """Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.""" + if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None + # We need to build the package in editable mode, so that we can import it + from build import ProjectBuilder + from build.env import IsolatedEnvBuilder + module_location = pkg.source_locations("openllm") + if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.") + pyproject_path = Path(module_location).parent.parent / "pyproject.toml" + if os.path.isfile(pyproject_path.__fspath__()): + logger.info("OpenLLM is installed in editable mode. Generating built wheels...") + with IsolatedEnvBuilder() as env: + builder = ProjectBuilder(pyproject_path.parent) + builder.python_executable = env.executable + builder.scripts_dir = env.scripts_dir + env.install(builder.build_system_requires) + return builder.build("wheel", path, config_settings={"--global-option": "--quiet"}) + raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.") -def construct_python_options( - llm: openllm.LLM[t.Any, t.Any], - llm_fs: FS, - extra_dependencies: tuple[str, ...] | None = None, - adapter_map: dict[str, str | None] | None = None, -) -> PythonOptions: - packages = ["openllm", "scipy"] # apparently bnb misses this one - if adapter_map is not None: packages += ["openllm[fine-tune]"] - # NOTE: add openllm to the default dependencies - # if users has openllm custom built wheels, it will still respect - # that since bentoml will always install dependencies from requirements.txt - # first, then proceed to install everything inside the wheels/ folder. - if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies] +def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions: + packages = ["openllm", "scipy"] # apparently bnb misses this one + if adapter_map is not None: packages += ["openllm[fine-tune]"] + # NOTE: add openllm to the default dependencies + # if users has openllm custom built wheels, it will still respect + # that since bentoml will always install dependencies from requirements.txt + # first, then proceed to install everything inside the wheels/ folder. + if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies] - req = llm.config["requirements"] - if req is not None: packages.extend(req) - if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}") + req = llm.config["requirements"] + if req is not None: packages.extend(req) + if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}") - env = llm.config["env"] - framework_envvar = env["framework_value"] - if framework_envvar == "flax": - if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'") - packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")]) - elif framework_envvar == "tf": - if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'") - candidates = ( - "tensorflow", - "tensorflow-cpu", - "tensorflow-gpu", - "tf-nightly", - "tf-nightly-cpu", - "tf-nightly-gpu", - "intel-tensorflow", - "intel-tensorflow-avx512", - "tensorflow-rocm", - "tensorflow-macos", - ) - # For the metadata, we have to look for both tensorflow and tensorflow-cpu - for candidate in candidates: - try: - pkgver = importlib.metadata.version(candidate) - if pkgver == candidate: packages.extend(["tensorflow"]) - else: - _tf_version = importlib.metadata.version(candidate) - packages.extend([f"tensorflow>={_tf_version}"]) - break - except importlib.metadata.PackageNotFoundError: pass - else: - if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.") - packages.extend([f'torch>={importlib.metadata.version("torch")}']) + env = llm.config["env"] + framework_envvar = env["framework_value"] + if framework_envvar == "flax": + if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'") + packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")]) + elif framework_envvar == "tf": + if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'") + candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos",) + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for candidate in candidates: + try: + pkgver = importlib.metadata.version(candidate) + if pkgver == candidate: packages.extend(["tensorflow"]) + else: + _tf_version = importlib.metadata.version(candidate) + packages.extend([f"tensorflow>={_tf_version}"]) + break + except importlib.metadata.PackageNotFoundError: + pass + else: + if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.") + packages.extend([f'torch>={importlib.metadata.version("torch")}']) - wheels: list[str] = [] - built_wheels = build_editable(llm_fs.getsyspath("/")) - if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}")) - return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"]) + wheels: list[str] = [] + built_wheels = build_editable(llm_fs.getsyspath("/")) + if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}")) + return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"]) def construct_docker_options( - llm: openllm.LLM[t.Any, t.Any], - _: FS, - workers_per_resource: int | float, - quantize: t.LiteralString | None, - bettertransformer: bool | None, - adapter_map: dict[str, str | None] | None, - dockerfile_template: str | None, - runtime: t.Literal["ggml", "transformers"], - serialisation_format: t.Literal["safetensors", "legacy"], - container_registry: LiteralContainerRegistry, + llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy, ) -> DockerOptions: - _bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "") - _bentoml_config_options_opts = [ - "api_server.traffic.timeout=36000", # NOTE: Currently we hardcode this value - f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', - f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}', - ] - _bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts) - env: EnvVarMixin = llm.config["env"] - env_dict = { - env.framework: env.framework_value, - env.config: f"'{llm.config.model_dump_json().decode()}'", - "OPENLLM_MODEL": llm.config["model_name"], - "OPENLLM_SERIALIZATION": serialisation_format, - "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", - "OPENLLM_FAST": str(True), - "BENTOML_DEBUG": str(True), - "BENTOML_QUIET": str(False), - "OPENLLMDEVDEBUG": str(get_debug_mode()), - "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", - env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var - } - if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") + _bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "") + _bentoml_config_options_opts = [ + "api_server.traffic.timeout=36000", # NOTE: Currently we hardcode this value + f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}', + ] + _bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts) + env: EnvVarMixin = llm.config["env"] + env_dict = { + env.framework: env.framework_value, env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "OPENLLM_FAST": str(True), "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "OPENLLMDEVDEBUG": str(get_debug_mode()), + "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var + } + if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") - # We need to handle None separately here, as env from subprocess doesn't accept None value. - _env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) + # We need to handle None separately here, as env from subprocess doesn't accept None value. + _env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) - if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value) - if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value - env_dict[_env.runtime] = _env.runtime_value - return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}",env=env_dict, dockerfile_template=dockerfile_template) + if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value) + if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value + env_dict[_env.runtime] = _env.runtime_value + return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template) @inject def create_bento( - bento_tag: bentoml.Tag, - llm_fs: FS, - llm: openllm.LLM[t.Any, t.Any], - workers_per_resource: str | int | float, - quantize: t.LiteralString | None, - bettertransformer: bool | None, - dockerfile_template: str | None, - adapter_map: dict[str, str | None] | None = None, - extra_dependencies: tuple[str, ...] | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers", - serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", - container_registry: LiteralContainerRegistry = "ecr", - container_version_strategy: LiteralContainerVersionStrategy = "release", - _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], - _model_store: ModelStore = Provide[BentoMLContainer.model_store], + bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", + serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release", _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store], ) -> bentoml.Bento: - framework_envvar = llm.config["env"]["framework_value"] - labels = dict(llm.identifying_params) - labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"}) - if adapter_map: labels.update(adapter_map) - if isinstance(workers_per_resource, str): - if workers_per_resource == "round_robin": workers_per_resource = 1.0 - elif workers_per_resource == "conserved": workers_per_resource = 1.0 if device_count() == 0 else float(1 / device_count()) - else: - try: workers_per_resource = float(workers_per_resource) - except ValueError: raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None - elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource) + framework_envvar = llm.config["env"]["framework_value"] + labels = dict(llm.identifying_params) + labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"}) + if adapter_map: labels.update(adapter_map) + if isinstance(workers_per_resource, str): + if workers_per_resource == "round_robin": workers_per_resource = 1.0 + elif workers_per_resource == "conserved": workers_per_resource = 1.0 if device_count() == 0 else float(1 / device_count()) + else: + try: + workers_per_resource = float(workers_per_resource) + except ValueError: + raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None + elif isinstance(workers_per_resource, int): + workers_per_resource = float(workers_per_resource) - logger.info("Building Bento for '%s'", llm.config["start_name"]) - # add service.py definition to this temporary folder - codegen.write_service(llm, adapter_map, llm_fs) + logger.info("Building Bento for '%s'", llm.config["start_name"]) + # add service.py definition to this temporary folder + codegen.write_service(llm, adapter_map, llm_fs) - llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name}) - build_config = BentoBuildConfig( - service=f"{llm.config['service_name']}:svc", - name=bento_tag.name, - labels=labels, - description=f"OpenLLM service for {llm.config['start_name']}", - include=list(llm_fs.walk.files()), - exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"], - python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map), - docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy), - models=[llm_spec], - ) + llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name}) + build_config = BentoBuildConfig( + service=f"{llm.config['service_name']}:svc", name=bento_tag.name, labels=labels, description=f"OpenLLM service for {llm.config['start_name']}", include=list(llm_fs.walk.files()), exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"], python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map), + docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy), models=[llm_spec], + ) - bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/")) - # NOTE: the model_id_path here are only used for setting this environment variable within the container - # built with for BentoLLM. - service_fs_path = fs.path.join("src", llm.config["service_name"]) - service_path = bento._fs.getsyspath(service_fs_path) - with open(service_path, "r") as f: service_contents = f.readlines() + bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/")) + # NOTE: the model_id_path here are only used for setting this environment variable within the container + # built with for BentoLLM. + service_fs_path = fs.path.join("src", llm.config["service_name"]) + service_path = bento._fs.getsyspath(service_fs_path) + with open(service_path, "r") as f: + service_contents = f.readlines() - for it in service_contents: - if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag)) + for it in service_contents: + if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag)) - script = "".join(service_contents) - if DEBUG: logger.info("Generated script:\n%s", script) + script = "".join(service_contents) + if DEBUG: logger.info("Generated script:\n%s", script) - bento._fs.writetext(service_fs_path, script) - if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store) - # backward arguments. `model_store` is added recently - return bento.save(bento_store=_bento_store) + bento._fs.writetext(service_fs_path, script) + if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store) + # backward arguments. `model_store` is added recently + return bento.save(bento_store=_bento_store) diff --git a/src/openllm/bundle/oci/__init__.py b/src/openllm/bundle/oci/__init__.py index 1aaac70c..f806ad10 100644 --- a/src/openllm/bundle/oci/__init__.py +++ b/src/openllm/bundle/oci/__init__.py @@ -42,8 +42,7 @@ LiteralContainerVersionStrategy = t.Literal["release", "nightly", "latest"] # but in the future, we can infer based on git repo and everything to make it more options for users # to build the base image. For now, all of the base image will be /bentoml/openllm:... # NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it. -_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm", - "ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"} +_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm", "ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"} _URI = "https://github.com/bentoml/openllm.git" @@ -51,56 +50,63 @@ _module_location = pkg.source_locations("openllm") @functools.lru_cache @apply(str.lower) -def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg] +def get_base_container_name(reg: LiteralContainerRegistry) -> str: + return _CONTAINER_REGISTRY[reg] @functools.lru_cache(maxsize=1) -def _git() -> git.cmd.Git: return git.cmd.Git(_URI) +def _git() -> git.cmd.Git: + return git.cmd.Git(_URI) @functools.lru_cache -def _nightly_ref() -> tuple[str, str]: return _git().ls_remote(_URI, "main", heads=True).split() +def _nightly_ref() -> tuple[str, str]: + return _git().ls_remote(_URI, "main", heads=True).split() @functools.lru_cache -def _stable_ref() -> tuple[str, str]: return max([item.split() for item in _git().ls_remote(_URI, refs=True, tags=True).split("\n")], - key = lambda tag: tuple(int(k) for k in tag[-1].replace("refs/tags/v", "").split("."))) +def _stable_ref() -> tuple[str, str]: + return max([item.split() for item in _git().ls_remote(_URI, refs=True, tags=True).split("\n")], key=lambda tag: tuple(int(k) for k in tag[-1].replace("refs/tags/v", "").split("."))) def get_base_container_tag(strategy: LiteralContainerVersionStrategy) -> str: - if strategy == "release": return _stable_ref()[-1].replace("refs/tags/v", "") # for stable, we can also use latest, but discouraged - elif strategy == "latest": return "latest" - elif strategy == "nightly": return f"sha-{_nightly_ref()[0][:7]}" # we prefixed with sha- (giv_rev[:7]) - else: raise ValueError(f"Unknown strategy '{strategy}'. Valid strategies are 'release', 'nightly', and 'latest'") + if strategy == "release": return _stable_ref()[-1].replace("refs/tags/v", "") # for stable, we can also use latest, but discouraged + elif strategy == "latest": return "latest" + elif strategy == "nightly": return f"sha-{_nightly_ref()[0][:7]}" # we prefixed with sha- (giv_rev[:7]) + else: raise ValueError(f"Unknown strategy '{strategy}'. Valid strategies are 'release', 'nightly', and 'latest'") def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]: - """This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed. + """This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed. - Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry. - """ - try: - if not _BUILDER.health(): raise Error - except (Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None - if device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)") - if not shutil.which("nvidia-container-runtime"): raise RuntimeError("Make sure to have NVIDIA Container Toolkit setup correctly to compile CUDA kernel in container. See https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html for more information.") - if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)") - pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml" - if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'") - tags: dict[str | LiteralContainerRegistry, str] - if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy - else: - registries = [registries] if isinstance(registries, str) else list(registries) - tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries} - try: - outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), - push=push, progress="plain" if get_debug_mode() else "auto", quiet=machine) - if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip() - except Exception as err: raise OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err - return tags + Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry. + """ + try: + if not _BUILDER.health(): raise Error + except (Error, subprocess.CalledProcessError): + raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None + if device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)") + if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.") + if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)") + pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml" + if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'") + if t.TYPE_CHECKING: tags: dict[str | LiteralContainerRegistry, str] + if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy + else: + registries = [registries] if isinstance(registries, str) else list(registries) + tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries} + try: + outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if get_debug_mode() else "auto", quiet=machine) + if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip() + except Exception as err: + raise OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err + return tags if t.TYPE_CHECKING: - CONTAINER_NAMES: dict[LiteralContainerRegistry, str] - supported_registries: list[str] + CONTAINER_NAMES: dict[LiteralContainerRegistry, str] + supported_registries: list[str] __all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries"] -def __dir__() -> list[str]: return sorted(__all__) + +def __dir__() -> list[str]: + return sorted(__all__) + def __getattr__(name: str) -> t.Any: - if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))() - elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY - elif name in __all__: return importlib.import_module("." + name, __name__) - else: raise AttributeError(f"{name} does not exists under {__name__}") + if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))() + elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY + elif name in __all__: return importlib.import_module("." + name, __name__) + else: raise AttributeError(f"{name} does not exists under {__name__}") diff --git a/src/openllm/cli/__init__.py b/src/openllm/cli/__init__.py index d3ecb9a2..52cdfb76 100644 --- a/src/openllm/cli/__init__.py +++ b/src/openllm/cli/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """OpenLLM CLI. For more information see ``openllm -h``. diff --git a/src/openllm/cli/_factory.py b/src/openllm/cli/_factory.py index a9c2e631..cef26bd5 100644 --- a/src/openllm/cli/_factory.py +++ b/src/openllm/cli/_factory.py @@ -47,71 +47,64 @@ from ..utils import is_peft_available from ..utils import resolve_user_filepath if t.TYPE_CHECKING: - import subprocess + import subprocess - from .._configuration import LLMConfig - from .._types import DictStrAny - from .._types import P - TupleStr = tuple[str, ...] -else: TupleStr = tuple + from .._configuration import LLMConfig + from .._types import DictStrAny + from .._types import P + TupleStr = tuple[str, ...] +else: + TupleStr = tuple LiteralOutput = t.Literal["json", "pretty", "porcelain"] _AnyCallable = t.Callable[..., t.Any] FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command]) -def parse_config_options( - config: LLMConfig, - server_timeout: int, - workers_per_resource: float, - device: tuple[str, ...] | None, - environ: DictStrAny, -) -> DictStrAny: - # TODO: Support amd.com/gpu on k8s - _bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "") - _bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'] - if device: - if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)]) - else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]') - _bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts) - environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env - return environ +def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny: + # TODO: Support amd.com/gpu on k8s + _bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "") + _bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'] + if device: + if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)]) + else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]') + _bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts) + environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env + return environ _adapter_mapping_key = "adapter_map" -def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None: - if not value: return None - if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {} - for v in value: - adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) - # try to resolve the full path if users pass in relative, - # currently only support one level of resolve path with current directory - try: adapter_id = resolve_user_filepath(adapter_id, os.getcwd()) - except FileNotFoundError: pass - ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None - return None +def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None: + if not value: return None + if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {} + for v in value: + adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) + # try to resolve the full path if users pass in relative, + # currently only support one level of resolve path with current directory + try: + adapter_id = resolve_user_filepath(adapter_id, os.getcwd()) + except FileNotFoundError: + pass + ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None + return None def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command: - """Generate a 'click.Command' for any given LLM. + """Generate a 'click.Command' for any given LLM. - Args: - group: the target ``click.Group`` to save this LLM cli under - model: The name of the model or the ``bentoml.Bento`` instance. + Args: + group: the target ``click.Group`` to save this LLM cli under + model: The name of the model or the ``bentoml.Bento`` instance. - Returns: - The click.Command for starting the model server + Returns: + The click.Command for starting the model server - Note that the internal commands will return the llm_config and a boolean determine - whether the server is run with GPU or not. - """ - llm_config = AutoConfig.for_model(model) + Note that the internal commands will return the llm_config and a boolean determine + whether the server is run with GPU or not. + """ + llm_config = AutoConfig.for_model(model) - command_attrs: DictStrAny = dict( - name=llm_config["model_name"], - context_settings=_context_settings or termui.CONTEXT_SETTINGS, - short_help=f"Start a LLMServer for '{model}'", - aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None, - help=f"""\ + command_attrs: DictStrAny = dict( + name=llm_config["model_name"], context_settings=_context_settings or termui.CONTEXT_SETTINGS, short_help=f"Start a LLMServer for '{model}'", aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None, help=f"""\ {llm_config['env'].start_docstring} \b @@ -130,142 +123,125 @@ Available official model_id(s): [default: {llm_config['default_id']}] \b {orjson.dumps(llm_config['model_ids'], option=orjson.OPT_INDENT_2).decode()} """, - ) + ) - if llm_config["requires_gpu"] and device_count() < 1: - # NOTE: The model requires GPU, therefore we will return a dummy command - command_attrs.update({"short_help": "(Disabled because there is no GPU available)", "help": f"""{model} is currently not available to run on your local machine because it requires GPU for inference."""}) - return noop_command(group, llm_config, _serve_grpc, **command_attrs) + if llm_config["requires_gpu"] and device_count() < 1: + # NOTE: The model requires GPU, therefore we will return a dummy command + command_attrs.update({"short_help": "(Disabled because there is no GPU available)", "help": f"""{model} is currently not available to run on your local machine because it requires GPU for inference."""}) + return noop_command(group, llm_config, _serve_grpc, **command_attrs) + @group.command(**command_attrs) + @start_decorator(llm_config, serve_grpc=_serve_grpc) + @click.pass_context + def start_cmd( + ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool, + serialisation_format: t.Literal["safetensors", "legacy"], adapter_id: str | None, return_process: bool, **attrs: t.Any, + ) -> LLMConfig | subprocess.Popen[bytes]: + fast = str(fast).upper() in ENV_VARS_TRUE_VALUES + if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES: + termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow") + adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None) + config, server_attrs = llm_config.model_validate_click(**attrs) + server_timeout = first_not_none(server_timeout, default=config["timeout"]) + server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout}) + if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1" + # NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream. + development = server_attrs.pop("development") + server_attrs.setdefault("production", not development) + wpr = first_not_none(workers_per_resource, default=config["workers_per_resource"]) - @group.command(**command_attrs) - @start_decorator(llm_config, serve_grpc=_serve_grpc) - @click.pass_context - def start_cmd( - ctx: click.Context, /, - server_timeout: int, - model_id: str | None, - model_version: str | None, - workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, - device: tuple[str, ...], - quantize: t.Literal["int8", "int4", "gptq"] | None, - bettertransformer: bool | None, - runtime: t.Literal["ggml", "transformers"], - fast: bool, - serialisation_format: t.Literal["safetensors", "legacy"], - adapter_id: str | None, - return_process: bool, - **attrs: t.Any, - ) -> LLMConfig | subprocess.Popen[bytes]: - fast = str(fast).upper() in ENV_VARS_TRUE_VALUES - if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES: - termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow") - adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None) - config, server_attrs = llm_config.model_validate_click(**attrs) - server_timeout = first_not_none(server_timeout, default=config["timeout"]) - server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout}) - if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1" - # NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream. - development = server_attrs.pop("development") - server_attrs.setdefault("production", not development) - wpr = first_not_none(workers_per_resource, default=config["workers_per_resource"]) - - if isinstance(wpr, str): - if wpr == "round_robin": wpr = 1.0 - elif wpr == "conserved": - if device and device_count() == 0: - termui.echo("--device will have no effect as there is no GPUs available", fg="yellow") - wpr = 1.0 - else: - available_gpu = len(device) if device else device_count() - wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu) - else: wpr = float(wpr) - elif isinstance(wpr, int): wpr = float(wpr) - - # Create a new model env to work with the envvar during CLI invocation - env = EnvVarMixin(config["model_name"], config.default_implementation(), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) - prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr)) - - # NOTE: This is to set current configuration - start_env = os.environ.copy() - start_env = parse_config_options(config, server_timeout, wpr, device, start_env) - if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow") - - start_env.update( - { - "OPENLLM_MODEL": model, - "BENTOML_DEBUG": str(get_debug_mode()), - "BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), - "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), - "OPENLLM_SERIALIZATION": serialisation_format, - env.runtime: env.runtime_value, - env.framework: env.framework_value, - } - ) - if env.model_id_value: start_env[env.model_id] = str(env.model_id_value) - # NOTE: quantize and bettertransformer value is already assigned within env - if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value) - if quantize is not None: start_env[env.quantize] = str(env.quantize_value) - - llm = infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format) - start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id}) - - server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs) - analytics.track_start_init(llm.config) - - def next_step(model_name: str, adapter_map: DictStrAny | None) -> None: - cmd_name = f"openllm build {model_name}" - if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]]) - if not get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue") - - if return_process: - server.start(env=start_env, text=True) - if server.process is None: raise click.ClickException("Failed to start the server.") - return server.process + if isinstance(wpr, str): + if wpr == "round_robin": wpr = 1.0 + elif wpr == "conserved": + if device and device_count() == 0: + termui.echo("--device will have no effect as there is no GPUs available", fg="yellow") + wpr = 1.0 else: - try: server.start(env=start_env, text=True, blocking=True) - except KeyboardInterrupt: next_step(model, adapter_map) - except Exception as err: termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red") - else: next_step(model, adapter_map) + available_gpu = len(device) if device else device_count() + wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu) + else: + wpr = float(wpr) + elif isinstance(wpr, int): + wpr = float(wpr) - # NOTE: Return the configuration for telemetry purposes. - return config + # Create a new model env to work with the envvar during CLI invocation + env = EnvVarMixin(config["model_name"], config.default_implementation(), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) + prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr)) - return start_cmd + # NOTE: This is to set current configuration + start_env = os.environ.copy() + start_env = parse_config_options(config, server_timeout, wpr, device, start_env) + if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow") + + start_env.update({"OPENLLM_MODEL": model, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), "OPENLLM_SERIALIZATION": serialisation_format, env.runtime: env.runtime_value, env.framework: env.framework_value,}) + if env.model_id_value: start_env[env.model_id] = str(env.model_id_value) + # NOTE: quantize and bettertransformer value is already assigned within env + if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value) + if quantize is not None: start_env[env.quantize] = str(env.quantize_value) + + llm = infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format) + start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id}) + + server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs) + analytics.track_start_init(llm.config) + + def next_step(model_name: str, adapter_map: DictStrAny | None) -> None: + cmd_name = f"openllm build {model_name}" + if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]]) + if not get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue") + + if return_process: + server.start(env=start_env, text=True) + if server.process is None: raise click.ClickException("Failed to start the server.") + return server.process + else: + try: + server.start(env=start_env, text=True, blocking=True) + except KeyboardInterrupt: + next_step(model, adapter_map) + except Exception as err: + termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red") + else: + next_step(model, adapter_map) + + # NOTE: Return the configuration for telemetry purposes. + return config + + return start_cmd def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, **command_attrs: t.Any) -> click.Command: - context_settings = command_attrs.pop("context_settings", {}) - context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True}) - command_attrs["context_settings"] = context_settings - # NOTE: The model requires GPU, therefore we will return a dummy command - @group.command(**command_attrs) - def noop(**_: t.Any) -> LLMConfig: - termui.echo("No GPU available, therefore this command is disabled", fg="red") - analytics.track_start_init(llm_config) - return llm_config - return noop + context_settings = command_attrs.pop("context_settings", {}) + context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True}) + command_attrs["context_settings"] = context_settings + # NOTE: The model requires GPU, therefore we will return a dummy command + @group.command(**command_attrs) + def noop(**_: t.Any) -> LLMConfig: + termui.echo("No GPU available, therefore this command is disabled", fg="red") + analytics.track_start_init(llm_config) + return llm_config + + return noop def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None: - if adapter_map and not is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'") - if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config.env['framework']}=\"pt\"' to run with quantization.") - requirements = llm_config["requirements"] - if requirements is not None and len(requirements) > 0: - missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None] - if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow") + if adapter_map and not is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'") + if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config.env['framework']}=\"pt\"' to run with quantization.") + requirements = llm_config["requirements"] + if requirements is not None and len(requirements) > 0: + missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None] + if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow") def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]: - return lambda fn: compose(*[ - llm_config.to_click_options, - _http_server_args if not serve_grpc else _grpc_server_args, - cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."), - model_id_option(factory=cog.optgroup, model_env=llm_config["env"]), - model_version_option(factory=cog.optgroup), - cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"), - workers_per_resource_option(factory=cog.optgroup), - fast_option(factory=cog.optgroup), - cog.optgroup.group( - "LLM Optimization Options", - help="""Optimization related options. + return lambda fn: compose( + *[ + llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args, + cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."), + model_id_option(factory=cog.optgroup, model_env=llm_config["env"]), + model_version_option(factory=cog.optgroup), + cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"), + workers_per_resource_option(factory=cog.optgroup), + fast_option(factory=cog.optgroup), + cog.optgroup.group( + "LLM Optimization Options", help="""Optimization related options. OpenLLM supports running model with [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/), k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM. @@ -275,15 +251,14 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab - DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/) - GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml) """, - ), - cog.optgroup.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help=f"Assign GPU devices (if available) for {llm_config['model_name']}.", show_envvar=True), - cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."), - quantize_option(factory=cog.optgroup, model_env=llm_config["env"]), - bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]), - serialisation_option(factory=cog.optgroup), - cog.optgroup.group( - "Fine-tuning related options", - help="""\ + ), + cog.optgroup.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help=f"Assign GPU devices (if available) for {llm_config['model_name']}.", show_envvar=True), + cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."), + quantize_option(factory=cog.optgroup, model_env=llm_config["env"]), + bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]), + serialisation_option(factory=cog.optgroup), + cog.optgroup.group( + "Fine-tuning related options", help="""\ Note that the argument `--adapter-id` can accept the following format: - `--adapter-id /path/to/adapter` (local adapter) @@ -298,18 +273,19 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab ``` """, - ), - cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"), - click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True), - ])(fn) + ), + cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"), + click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True), + ] + )(fn) def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> TupleStr | None: - if value is None: return value - if not LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})") - el: TupleStr = tuple(i for k in value for i in k) - # NOTE: --device all is a special case - if len(el) == 1 and el[0] == "all": return tuple(map(str, available_devices())) - return el + if value is None: return value + if not LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})") + el: TupleStr = tuple(i for k in value for i in k) + # NOTE: --device all is a special case + if len(el) == 1 and el[0] == "all": return tuple(map(str, available_devices())) + return el # NOTE: A list of bentoml option that is not needed for parsing. # NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this. @@ -317,67 +293,78 @@ def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tup _IGNORED_OPTIONS = {"working_dir", "production", "protocol_version"} def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]: - """Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.""" - from bentoml_cli.cli import cli + """Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.""" + from bentoml_cli.cli import cli - command = "serve" if not serve_grpc else "serve-grpc" - group = cog.optgroup.group( - f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", - help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]", - ) + command = "serve" if not serve_grpc else "serve-grpc" + group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",) - def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]: - serve_command = cli.commands[command] - # The first variable is the argument bento - # The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS - serve_options = [p for p in serve_command.params[1 :-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS] - for options in reversed(serve_options): - attrs = options.to_info_dict() - # we don't need param_type_name, since it should all be options - attrs.pop("param_type_name") - # name is not a valid args - attrs.pop("name") - # type can be determine from default value - attrs.pop("type") - param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts")) - f = cog.optgroup.option(*param_decls, **attrs)(f) - return group(f) - return decorator + def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]: + serve_command = cli.commands[command] + # The first variable is the argument bento + # The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS + serve_options = [p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS] + for options in reversed(serve_options): + attrs = options.to_info_dict() + # we don't need param_type_name, since it should all be options + attrs.pop("param_type_name") + # name is not a valid args + attrs.pop("name") + # type can be determine from default value + attrs.pop("type") + param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts")) + f = cog.optgroup.option(*param_decls, **attrs)(f) + return group(f) + + return decorator _http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True) def cli_option(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]: - """General ``@click.option`` with some sauce. + """General ``@click.option`` with some sauce. - This decorator extends the default ``@click.option`` plus a factory option to use which type of option, for example: [click, click_option_group.optgroup] - """ - attrs.setdefault("help", "General option for OpenLLM CLI.") - factory = attrs.pop("factory", click) - def decorator(f: FC | None) -> FC: return t.cast(FC, factory.option(*param_decls, **attrs)(f) if f is not None else factory.option(*param_decls, **attrs)) - return decorator + This decorator extends the default ``@click.option`` plus a factory option to use which type of option, for example: [click, click_option_group.optgroup] + """ + attrs.setdefault("help", "General option for OpenLLM CLI.") + factory = attrs.pop("factory", click) + + def decorator(f: FC | None) -> FC: + return t.cast(FC, factory.option(*param_decls, **attrs)(f) if f is not None else factory.option(*param_decls, **attrs)) + + return decorator def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]: - output = ["json", "pretty", "porcelain"] - def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: return [CompletionItem(it) for it in output] - return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, - envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f) -def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, - help="""Whether to skip checking if models is already in store. + output = ["json", "pretty", "porcelain"] + + def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: + return [CompletionItem(it) for it in output] + + return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f) + +def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + "--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, help="""Whether to skip checking if models is already in store. This is useful if you already downloaded or setup the model beforehand. - """, **attrs)(f) -def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f) -def model_id_option(f: _AnyCallable | None = None, *, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, - show_envvar=model_env is not None, - help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f) -def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None, - help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f) + """, **attrs + )(f) + +def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f) + +def model_id_option(f: _AnyCallable | None = None, *, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f) + +def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f) + def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]: - arg = click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in CONFIG_MAPPING]), required=required) - return arg(f) if f is not None else arg -def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None, - envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None, - help="""Dynamic quantization for running this LLM. + arg = click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in CONFIG_MAPPING]), required=required) + return arg(f) if f is not None else arg + +def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + "--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None, envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None, help="""Dynamic quantization for running this LLM. The following quantization strategies are supported: @@ -388,11 +375,16 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model - ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323) **Note** that the model can also be served with quantized weights. - """ + (""" - **Note** that this will set the mode for serving within deployment.""" if build else "") + """ - **Note** that quantization are currently only available in *PyTorch* models.""", **attrs)(f) -def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--workers-per-resource", default=None, callback=workers_per_resource_callback, type=str, required=False, - help="""Number of workers per resource assigned. + """ + ( + """ + **Note** that this will set the mode for serving within deployment.""" if build else "" + ) + """ + **Note** that quantization are currently only available in *PyTorch* models.""", **attrs + )(f) + +def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + "--workers-per-resource", default=None, callback=workers_per_resource_callback, type=str, required=False, help="""Number of workers per resource assigned. See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more information. By default, this is set to 1. @@ -402,16 +394,22 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. - ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``. - """ + ("""\n + """ + ( + """\n **Note**: The workers value passed into 'build' will determine how the LLM can be provisioned in Kubernetes as well as in standalone container. This will - ensure it has the same effect with 'openllm start --workers ...'""" if build else ""), **attrs)(f) -def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--bettertransformer", is_flag=True, default=None, envvar=model_env.bettertransformer if model_env is not None else None, show_envvar=model_env is not None, - help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.", - **attrs)(f) -def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--serialisation", "--serialization", "serialisation_format", type=click.Choice(["safetensors", "legacy"]), - default="safetensors", show_default=True, show_envvar=True, envvar="OPENLLM_SERIALIZATION", - help="""Serialisation format for save/load LLM. + ensure it has the same effect with 'openllm start --workers ...'""" if build else "" + ), **attrs + )(f) + +def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + "--bettertransformer", is_flag=True, default=None, envvar=model_env.bettertransformer if model_env is not None else None, show_envvar=model_env is not None, help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.", **attrs + )(f) + +def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + "--serialisation", "--serialization", "serialisation_format", type=click.Choice(["safetensors", "legacy"]), default="safetensors", show_default=True, show_envvar=True, envvar="OPENLLM_SERIALIZATION", help="""Serialisation format for save/load LLM. Currently the following strategies are supported: @@ -429,29 +427,34 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal **Note** that GGML format is working in progress. """, **attrs - )(f) -def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY", - callback=container_registry_callback, - help="""The default container registry to get the base image for building BentoLLM. + )(f) + +def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + "--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY", callback=container_registry_callback, help="""The default container registry to get the base image for building BentoLLM. Currently, it supports 'ecr', 'ghcr.io', 'docker.io' \b **Note** that in order to build the base image, you will need a GPUs to compile custom kernel. See ``openllm ext build-base-container`` for more information. - """)(f) + """ + )(f) _wpr_strategies = {"round_robin", "conserved"} def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None: - if value is None: return value - value = inflection.underscore(value) - if value in _wpr_strategies: return value + if value is None: return value + value = inflection.underscore(value) + if value in _wpr_strategies: return value + else: + try: + float(value) # type: ignore[arg-type] + except ValueError: + raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None else: - try: float(value) # type: ignore[arg-type] - except ValueError: raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None - else: return value + return value def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None: - if value is None: return value - if value not in bundle.supported_registries: raise click.BadParameter(f"Value must be one of {bundle.supported_registries}", ctx, param) - return value + if value is None: return value + if value not in bundle.supported_registries: raise click.BadParameter(f"Value must be one of {bundle.supported_registries}", ctx, param) + return value diff --git a/src/openllm/cli/entrypoint.py b/src/openllm/cli/entrypoint.py index 41f41a82..bb903cc6 100644 --- a/src/openllm/cli/entrypoint.py +++ b/src/openllm/cli/entrypoint.py @@ -122,29 +122,30 @@ from ..utils import set_debug_mode from ..utils import set_quiet_mode if t.TYPE_CHECKING: - import jupytext - import nbformat - import torch + import jupytext + import nbformat + import torch - from bentoml._internal.bento import BentoStore - from bentoml._internal.container import DefaultBuilder + from bentoml._internal.bento import BentoStore + from bentoml._internal.container import DefaultBuilder - from .._configuration import LLMConfig - from .._schema import EmbeddingsOutput - from .._types import DictStrAny - from .._types import ListStr - from .._types import LiteralRuntime - from .._types import P - from ..bundle.oci import LiteralContainerRegistry - from ..bundle.oci import LiteralContainerVersionStrategy -else: torch, jupytext, nbformat = LazyLoader("torch", globals(), "torch"), LazyLoader("jupytext", globals(), "jupytext"), LazyLoader("nbformat", globals(), "nbformat") + from .._configuration import LLMConfig + from .._schema import EmbeddingsOutput + from .._types import DictStrAny + from .._types import ListStr + from .._types import LiteralRuntime + from .._types import P + from ..bundle.oci import LiteralContainerRegistry + from ..bundle.oci import LiteralContainerVersionStrategy +else: + torch, jupytext, nbformat = LazyLoader("torch", globals(), "torch"), LazyLoader("jupytext", globals(), "jupytext"), LazyLoader("nbformat", globals(), "nbformat") # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import overload + from typing import overload else: - from typing_extensions import overload + from typing_extensions import overload logger = logging.getLogger(__name__) @@ -161,8 +162,10 @@ ServeCommand = t.Literal["serve", "serve-grpc"] @attr.define class GlobalOptions: - cloud_context: str | None = attr.field(default=None, converter=attr.converters.default_if_none("default")) - def with_options(self, **attrs: t.Any) -> t.Self: return attr.evolve(self, **attrs) + cloud_context: str | None = attr.field(default=None, converter=attr.converters.default_if_none("default")) + + def with_options(self, **attrs: t.Any) -> t.Self: + return attr.evolve(self, **attrs) CmdType = t.TypeVar("CmdType", bound=click.Command) GrpType = t.TypeVar("GrpType", bound=click.Group) @@ -170,183 +173,219 @@ GrpType = t.TypeVar("GrpType", bound=click.Group) _object_setattr = object.__setattr__ class OpenLLMCommandGroup(BentoMLCommandGroup): - NUMBER_OF_COMMON_PARAMS = 5 # parameters in common_params + 1 faked group option header - @staticmethod - def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]: - # The following logics is similar to one of BentoMLCommandGroup - @cog.optgroup.group("Global options") - @cog.optgroup.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output.", show_envvar=True) - @cog.optgroup.option( "--debug", "--verbose", "debug", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs.", show_envvar=True) - @cog.optgroup.option("--do-not-track", is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help="Do not send usage info", show_envvar=True) - @cog.optgroup.option( "--context", "cloud_context", envvar="BENTOCLOUD_CONTEXT", type=click.STRING, default=None, help="BentoCloud context name.", show_envvar=True) - @click.pass_context - @functools.wraps(f) - def wrapper(ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs) -> t.Any: - ctx.obj = GlobalOptions(cloud_context=cloud_context) - if quiet: - set_quiet_mode(True) - if debug: logger.warning("'--quiet' passed; ignoring '--verbose/--debug'") - elif debug: set_debug_mode(True) - configure_logging() - return f(*args, **attrs) - return wrapper - @staticmethod - def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[t.Concatenate[bool, P], t.Any]: - command_name = attrs.get("name", func.__name__) - @functools.wraps(func) - def wrapper(do_not_track: bool, *args: P.args, **attrs: P.kwargs) -> t.Any: - if do_not_track: - with analytics.set_bentoml_tracking(): return func(*args, **attrs) - start_time = time.time_ns() - with analytics.set_bentoml_tracking(): - if group.name is None: raise ValueError("group.name should not be None") - event = analytics.OpenllmCliEvent(cmd_group=group.name, cmd_name=command_name) - try: - return_value = func(*args, **attrs) - duration_in_ms = (time.time_ns() - start_time) / 1e6 - event.duration_in_ms = duration_in_ms - analytics.track(event) - return return_value - except Exception as e: - duration_in_ms = (time.time_ns() - start_time) / 1e6 - event.duration_in_ms = duration_in_ms - event.error_type = type(e).__name__ - event.return_code = 2 if isinstance(e, KeyboardInterrupt) else 1 - analytics.track(event) - raise - return wrapper - @staticmethod - def exception_handling(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[P, t.Any]: - command_name = attrs.get("name", func.__name__) - @functools.wraps(func) - def wrapper(*args: P.args, **attrs: P.kwargs) -> t.Any: - try: return func(*args, **attrs) - except OpenLLMException as err: - raise click.ClickException(click.style(f"[{group.name}] '{command_name}' failed: " + err.message, fg="red")) from err - except KeyboardInterrupt: pass - return wrapper - def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: - cmd_name = self.resolve_alias(cmd_name) - if ctx.command.name in _start_mapping: - try: return _start_mapping[ctx.command.name][cmd_name] - except KeyError: - # TODO: support start from a bento - try: - bentoml.get(cmd_name) - raise click.ClickException(f"'openllm start {cmd_name}' is currently disabled for the time being. Please let us know if you need this feature by opening an issue on GitHub.") - except bentoml.exceptions.NotFound: pass - raise click.BadArgumentUsage(f"{cmd_name} is not a valid model identifier supported by OpenLLM.") from None - return super().get_command(ctx, cmd_name) - def list_commands(self, ctx: click.Context) -> list[str]: - if ctx.command.name in {"start", "start-grpc"}: return list(CONFIG_MAPPING.keys()) - return super().list_commands(ctx) - # NOTE: The following overload are ported from click to make sure - # cli.command is correctly typed. See https://github.com/pallets/click/blob/main/src/click/decorators.py#L136 - # + NUMBER_OF_COMMON_PARAMS = 5 # parameters in common_params + 1 faked group option header + + @staticmethod + def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]: + # The following logics is similar to one of BentoMLCommandGroup + @cog.optgroup.group("Global options") + @cog.optgroup.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output.", show_envvar=True) + @cog.optgroup.option("--debug", "--verbose", "debug", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs.", show_envvar=True) + @cog.optgroup.option("--do-not-track", is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help="Do not send usage info", show_envvar=True) + @cog.optgroup.option("--context", "cloud_context", envvar="BENTOCLOUD_CONTEXT", type=click.STRING, default=None, help="BentoCloud context name.", show_envvar=True) + @click.pass_context + @functools.wraps(f) + def wrapper(ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs) -> t.Any: + ctx.obj = GlobalOptions(cloud_context=cloud_context) + if quiet: + set_quiet_mode(True) + if debug: logger.warning("'--quiet' passed; ignoring '--verbose/--debug'") + elif debug: set_debug_mode(True) + configure_logging() + return f(*args, **attrs) + + return wrapper + + @staticmethod + def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[t.Concatenate[bool, P], t.Any]: + command_name = attrs.get("name", func.__name__) + + @functools.wraps(func) + def wrapper(do_not_track: bool, *args: P.args, **attrs: P.kwargs) -> t.Any: + if do_not_track: + with analytics.set_bentoml_tracking(): + return func(*args, **attrs) + start_time = time.time_ns() + with analytics.set_bentoml_tracking(): + if group.name is None: raise ValueError("group.name should not be None") + event = analytics.OpenllmCliEvent(cmd_group=group.name, cmd_name=command_name) + try: + return_value = func(*args, **attrs) + duration_in_ms = (time.time_ns() - start_time) / 1e6 + event.duration_in_ms = duration_in_ms + analytics.track(event) + return return_value + except Exception as e: + duration_in_ms = (time.time_ns() - start_time) / 1e6 + event.duration_in_ms = duration_in_ms + event.error_type = type(e).__name__ + event.return_code = 2 if isinstance(e, KeyboardInterrupt) else 1 + analytics.track(event) + raise + + return wrapper + + @staticmethod + def exception_handling(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[P, t.Any]: + command_name = attrs.get("name", func.__name__) + + @functools.wraps(func) + def wrapper(*args: P.args, **attrs: P.kwargs) -> t.Any: + try: + return func(*args, **attrs) + except OpenLLMException as err: + raise click.ClickException(click.style(f"[{group.name}] '{command_name}' failed: " + err.message, fg="red")) from err + except KeyboardInterrupt: + pass + + return wrapper + + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: + cmd_name = self.resolve_alias(cmd_name) + if ctx.command.name in _start_mapping: + try: + return _start_mapping[ctx.command.name][cmd_name] + except KeyError: + # TODO: support start from a bento + try: + bentoml.get(cmd_name) + raise click.ClickException(f"'openllm start {cmd_name}' is currently disabled for the time being. Please let us know if you need this feature by opening an issue on GitHub.") + except bentoml.exceptions.NotFound: + pass + raise click.BadArgumentUsage(f"{cmd_name} is not a valid model identifier supported by OpenLLM.") from None + return super().get_command(ctx, cmd_name) + + def list_commands(self, ctx: click.Context) -> list[str]: + if ctx.command.name in {"start", "start-grpc"}: return list(CONFIG_MAPPING.keys()) + return super().list_commands(ctx) + + # NOTE: The following overload are ported from click to make sure + # cli.command is correctly typed. See https://github.com/pallets/click/blob/main/src/click/decorators.py#L136 + # + # variant: no call, directly as decorator for a function. + @overload + def command(self, name: _AnyCallable) -> click.Command: + ... + + # variant: with positional name and with positional or keyword cls argument: + # @command(namearg, CommandCls, ...) or @command(namearg, cls=CommandCls, ...) + @overload + def command(self, name: str | None, cls: type[CmdType], **attrs: t.Any) -> t.Callable[[_AnyCallable], CmdType]: + ... + + # variant: name omitted, cls _must_ be a keyword argument, @command(cmd=CommandCls, ...) + @overload + def command(self, name: None = None, *, cls: type[CmdType], **attrs: t.Any) -> t.Callable[[_AnyCallable], CmdType]: + ... + + # variant: name omitted, only provide keyword arguments, @command(context_settings={}) + @overload + def command(self, *, cls: type[CmdType], **attrs: t.Any) -> t.Callable[[_AnyCallable], CmdType]: + ... + + # variant: with optional string name, no cls argument provided. + @overload + def command(self, name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any) -> t.Callable[[_AnyCallable], click.Command]: + ... + + def command(self, name: str | None | _AnyCallable = None, cls: type[CmdType] | None = None, *args: t.Any, **attrs: t.Any) -> click.Command | t.Callable[[_AnyCallable], click.Command | CmdType]: + """Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters.""" + if "context_settings" not in attrs: attrs["context_settings"] = {} + if "max_content_width" not in attrs["context_settings"]: attrs["context_settings"]["max_content_width"] = 120 + aliases = attrs.pop("aliases", None) + + def decorator(f: _AnyCallable) -> click.Command: + name = f.__name__.lower() + if name.endswith("_command"): name = name[:-8] + name = name.replace("_", "-") + attrs.setdefault("cls", cls) + attrs.setdefault("help", inspect.getdoc(f)) + attrs.setdefault("name", name) + + # Wrap implementation withc common parameters + wrapped = self.common_params(f) + # Wrap into OpenLLM tracking + wrapped = self.usage_tracking(wrapped, self, **attrs) + # Wrap into exception handling + wrapped = self.exception_handling(wrapped, self, **attrs) + + # move common parameters to end of the parameters list + _memo = getattr(wrapped, "__click_params__", None) + if _memo is None: raise RuntimeError("Click command not register correctly.") + _object_setattr(wrapped, "__click_params__", _memo[-self.NUMBER_OF_COMMON_PARAMS:] + _memo[:-self.NUMBER_OF_COMMON_PARAMS]) + # NOTE: we need to call super of super to avoid conflict with BentoMLCommandGroup command setup + cmd = super(BentoMLCommandGroup, self).command(*args, **attrs)(wrapped) + # NOTE: add aliases to a given commands if it is specified. + if aliases is not None: + if not cmd.name: raise ValueError("name is required when aliases are available.") + self._commands[cmd.name] = aliases + self._aliases.update({alias: cmd.name for alias in aliases}) + return cmd + + return decorator + + if t.TYPE_CHECKING: # variant: no call, directly as decorator for a function. @overload - def command(self, name: _AnyCallable) -> click.Command: ... + def group(self, name: _AnyCallable) -> click.Group: + ... + # variant: with positional name and with positional or keyword cls argument: - # @command(namearg, CommandCls, ...) or @command(namearg, cls=CommandCls, ...) + # @group(namearg, GroupCls, ...) or @group(namearg, cls=GroupCls, ...) @overload - def command(self, name: str | None, cls: type[CmdType], **attrs: t.Any) -> t.Callable[[_AnyCallable], CmdType]: ... - # variant: name omitted, cls _must_ be a keyword argument, @command(cmd=CommandCls, ...) + def group(self, name: str | None, cls: type[GrpType], **attrs: t.Any) -> t.Callable[[_AnyCallable], GrpType]: + ... + + # variant: name omitted, cls _must_ be a keyword argument, @group(cmd=GroupCls, ...) @overload - def command(self, name: None = None, *, cls: type[CmdType], **attrs: t.Any) -> t.Callable[[_AnyCallable], CmdType]: ... - # variant: name omitted, only provide keyword arguments, @command(context_settings={}) - @overload - def command(self, *, cls: type[CmdType], **attrs: t.Any) -> t.Callable[[_AnyCallable], CmdType]: ... + def group(self, name: None = None, *, cls: t.Type[GrpType], **attrs: t.Any) -> t.Callable[[_AnyCallable], GrpType]: + ... + # variant: with optional string name, no cls argument provided. @overload - def command(self, name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any) -> t.Callable[[_AnyCallable], click.Command]: ... - def command(self, name: str | None | _AnyCallable = None, cls: type[CmdType] | None = None, *args: t.Any, **attrs: t.Any) -> click.Command | t.Callable[[_AnyCallable], click.Command | CmdType]: - """Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters.""" - if "context_settings" not in attrs: attrs["context_settings"] = {} - if "max_content_width" not in attrs["context_settings"]: attrs["context_settings"]["max_content_width"] = 120 - aliases = attrs.pop("aliases", None) - def decorator(f: _AnyCallable) -> click.Command: - name = f.__name__.lower() - if name.endswith("_command"): name = name[:-8] - name = name.replace("_", "-") - attrs.setdefault("cls", cls) - attrs.setdefault("help", inspect.getdoc(f)) - attrs.setdefault("name", name) + def group(self, name: str | None = ..., cls: None = None, **attrs: t.Any) -> t.Callable[[_AnyCallable], click.Group]: + ... - # Wrap implementation withc common parameters - wrapped = self.common_params(f) - # Wrap into OpenLLM tracking - wrapped = self.usage_tracking(wrapped, self, **attrs) - # Wrap into exception handling - wrapped = self.exception_handling(wrapped, self, **attrs) - - # move common parameters to end of the parameters list - _memo = getattr(wrapped, "__click_params__", None) - if _memo is None: raise RuntimeError("Click command not register correctly.") - _object_setattr(wrapped, "__click_params__", _memo[-self.NUMBER_OF_COMMON_PARAMS:] + _memo[:-self.NUMBER_OF_COMMON_PARAMS]) - # NOTE: we need to call super of super to avoid conflict with BentoMLCommandGroup command setup - cmd = super(BentoMLCommandGroup, self).command(*args, **attrs)(wrapped) - # NOTE: add aliases to a given commands if it is specified. - if aliases is not None: - if not cmd.name: raise ValueError("name is required when aliases are available.") - self._commands[cmd.name] = aliases - self._aliases.update({alias: cmd.name for alias in aliases}) - return cmd - return decorator - - if t.TYPE_CHECKING: - # variant: no call, directly as decorator for a function. - @overload - def group(self, name: _AnyCallable) -> click.Group:... - # variant: with positional name and with positional or keyword cls argument: - # @group(namearg, GroupCls, ...) or @group(namearg, cls=GroupCls, ...) - @overload - def group(self, name: str | None, cls: type[GrpType], **attrs: t.Any) -> t.Callable[[_AnyCallable], GrpType]: ... - # variant: name omitted, cls _must_ be a keyword argument, @group(cmd=GroupCls, ...) - @overload - def group(self, name: None = None, *, cls: t.Type[GrpType], **attrs: t.Any) -> t.Callable[[_AnyCallable], GrpType]: ... - # variant: with optional string name, no cls argument provided. - @overload - def group(self, name: str | None = ..., cls: None = None, **attrs: t.Any) -> t.Callable[[_AnyCallable], click.Group]: ... - def group(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[_AnyCallable], click.Group]: ... + def group(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[_AnyCallable], click.Group]: + ... @click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="openllm") @click.version_option(None, "--version", "-v") def cli() -> None: - """\b - ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ - ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ - ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ - ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ - ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ - ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝. + """\b + ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ + ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ + ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ + ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ + ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ + ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝. - \b - An open platform for operating large language models in production. - Fine-tune, serve, deploy, and monitor any LLMs with ease. - """ # noqa: D205 + \b + An open platform for operating large language models in production. + Fine-tune, serve, deploy, and monitor any LLMs with ease. + """ # noqa: D205 @cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="start", aliases=["start-http"]) def start_command() -> None: - """Start any LLM as a REST server. + """Start any LLM as a REST server. - \b - ```bash - $ openllm -- ... - ``` - """ + \b + ```bash + $ openllm -- ... + ``` + """ @cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="start-grpc") def start_grpc_command() -> None: - """Start any LLM as a gRPC server. + """Start any LLM as a gRPC server. - \b - ```bash - $ openllm start-grpc -- ... - ``` - """ + \b + ```bash + $ openllm start-grpc -- ... + ``` + """ -_start_mapping = {"start": {key: start_command_factory(start_command, key, _context_settings=termui.CONTEXT_SETTINGS) for key in CONFIG_MAPPING}, - "start-grpc": {key: start_command_factory(start_grpc_command, key, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=True) for key in CONFIG_MAPPING}} +_start_mapping = {"start": {key: start_command_factory(start_command, key, _context_settings=termui.CONTEXT_SETTINGS) for key in CONFIG_MAPPING}, "start-grpc": {key: start_command_factory(start_grpc_command, key, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=True) for key in CONFIG_MAPPING}} @cli.command(name="import", aliases=["download"]) @model_name_argument @@ -359,347 +398,275 @@ _start_mapping = {"start": {key: start_command_factory(start_command, key, _cont @machine_option @click.option("--implementation", type=click.Choice(["pt", "tf", "flax", "vllm"]), default=None, help="The implementation for saving this LLM.") @serialisation_option -def import_command( - model_name: str, - model_id: str | None, - converter: str | None, - model_version: str | None, - output: LiteralOutput, - runtime: t.Literal["ggml", "transformers"], - machine: bool, - implementation: LiteralRuntime | None, - quantize: t.Literal["int8", "int4", "gptq"] | None, - serialisation_format: t.Literal["safetensors", "legacy"], -) -> bentoml.Model: - """Setup LLM interactively. +def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, runtime: t.Literal["ggml", "transformers"], machine: bool, implementation: LiteralRuntime | None, quantize: t.Literal["int8", "int4", "gptq"] | None, serialisation_format: t.Literal["safetensors", "legacy"],) -> bentoml.Model: + """Setup LLM interactively. - It accepts two positional arguments: `model_name` and `model_id`. The first name determine - the model type to download, and the second one is the optional model id to download. + It accepts two positional arguments: `model_name` and `model_id`. The first name determine + the model type to download, and the second one is the optional model id to download. - \b - This `model_id` can be either pretrained model id that you can get from HuggingFace Hub, or - a custom model path from your custom pretrained model. Note that the custom model path should - contain all files required to construct `transformers.PretrainedConfig`, `transformers.PreTrainedModel` - and `transformers.PreTrainedTokenizer` objects. + \b + This `model_id` can be either pretrained model id that you can get from HuggingFace Hub, or + a custom model path from your custom pretrained model. Note that the custom model path should + contain all files required to construct `transformers.PretrainedConfig`, `transformers.PreTrainedModel` + and `transformers.PreTrainedTokenizer` objects. - \b - Note: This is useful for development and setup for fine-tune. - This will be automatically called when `ensure_available=True` in `openllm.LLM.for_model` + \b + Note: This is useful for development and setup for fine-tune. + This will be automatically called when `ensure_available=True` in `openllm.LLM.for_model` - \b - ``--model-version`` is an optional option to save the model. Note that - this is recommended when the model_id is a custom path. Usually, if you are only using pretrained - model from HuggingFace Hub, you don't need to specify this. If this is not specified, we will calculate - the hash from the last modified time from this custom path + \b + ``--model-version`` is an optional option to save the model. Note that + this is recommended when the model_id is a custom path. Usually, if you are only using pretrained + model from HuggingFace Hub, you don't need to specify this. If this is not specified, we will calculate + the hash from the last modified time from this custom path - \b - ```bash - $ openllm download opt facebook/opt-2.7b - ``` + \b + ```bash + $ openllm download opt facebook/opt-2.7b + ``` - \b - > If ``quantize`` is passed, the model weights will be saved as quantized weights. You should - > only use this option if you want the weight to be quantized by default. Note that OpenLLM also - > support on-demand quantisation during initial startup. + \b + > If ``quantize`` is passed, the model weights will be saved as quantized weights. You should + > only use this option if you want the weight to be quantized by default. Note that OpenLLM also + > support on-demand quantisation during initial startup. - \b - ## Conversion strategies [EXPERIMENTAL] + \b + ## Conversion strategies [EXPERIMENTAL] - \b - Some models will include built-in conversion strategies for specific weights format. - It will be determined via the `CONVERTER` environment variable. Note that this envvar should only be use provisionally as it is not RECOMMENDED to export this - and save to a ``.env`` file. + \b + Some models will include built-in conversion strategies for specific weights format. + It will be determined via the `CONVERTER` environment variable. Note that this envvar should only be use provisionally as it is not RECOMMENDED to export this + and save to a ``.env`` file. - The conversion strategies will have the following format and will be determined per architecture implementation: - - + The conversion strategies will have the following format and will be determined per architecture implementation: + - - \b - For example: the below convert LlaMA-2 model format to hf: + \b + For example: the below convert LlaMA-2 model format to hf: - \b - ```bash - $ CONVERTER=llama2-hf openllm import llama /path/to/llama-2 - ``` + \b + ```bash + $ CONVERTER=llama2-hf openllm import llama /path/to/llama-2 + ``` - > **Note**: This behaviour will override ``--runtime``. Therefore make sure that the LLM contains correct conversion strategies to both GGML and HF. - """ - llm_config = AutoConfig.for_model(model_name) - env = EnvVarMixin(model_name, llm_config.default_implementation(), model_id=model_id, runtime=runtime, quantize=quantize) - impl: LiteralRuntime = first_not_none(implementation, default=env.framework_value) - llm = infer_auto_class(impl).for_model(model_name, llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format) - _previously_saved = False - try: - _ref = serialisation.get(llm) - _previously_saved = True - except bentoml.exceptions.NotFound: - if not machine and output == "pretty": - msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store. Saving to BENTOML_HOME{' (path=' + os.getenv('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..." - termui.echo(msg, fg="yellow", nl=True) - _ref = serialisation.get(llm, auto_import=True) - if impl == "pt" and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() - if machine: return _ref - elif output == "pretty": - if _previously_saved: termui.echo(f"{model_name} with 'model_id={model_id}' is already setup for framework '{impl}': {_ref.tag!s}", nl=True, fg="yellow") - else: termui.echo(f"Saved model: {_ref.tag}") - elif output == "json": termui.echo(orjson.dumps({"previously_setup": _previously_saved, "framework": impl, "tag": str(_ref.tag)}, option=orjson.OPT_INDENT_2).decode()) - else: termui.echo(_ref.tag) - return _ref + > **Note**: This behaviour will override ``--runtime``. Therefore make sure that the LLM contains correct conversion strategies to both GGML and HF. + """ + llm_config = AutoConfig.for_model(model_name) + env = EnvVarMixin(model_name, llm_config.default_implementation(), model_id=model_id, runtime=runtime, quantize=quantize) + impl: LiteralRuntime = first_not_none(implementation, default=env.framework_value) + llm = infer_auto_class(impl).for_model(model_name, llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format) + _previously_saved = False + try: + _ref = serialisation.get(llm) + _previously_saved = True + except bentoml.exceptions.NotFound: + if not machine and output == "pretty": + msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store. Saving to BENTOML_HOME{' (path=' + os.getenv('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..." + termui.echo(msg, fg="yellow", nl=True) + _ref = serialisation.get(llm, auto_import=True) + if impl == "pt" and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() + if machine: return _ref + elif output == "pretty": + if _previously_saved: termui.echo(f"{model_name} with 'model_id={model_id}' is already setup for framework '{impl}': {_ref.tag!s}", nl=True, fg="yellow") + else: termui.echo(f"Saved model: {_ref.tag}") + elif output == "json": termui.echo(orjson.dumps({"previously_setup": _previously_saved, "framework": impl, "tag": str(_ref.tag)}, option=orjson.OPT_INDENT_2).decode()) + else: termui.echo(_ref.tag) + return _ref def _start( - model_name: str, - /, *, - model_id: str | None = None, - timeout: int = 30, - workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, - device: tuple[str, ...] | t.Literal["all"] | None = None, - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - bettertransformer: bool | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers", - fast: bool = False, - adapter_map: dict[t.LiteralString, str | None] | None = None, - framework: LiteralRuntime | None = None, - additional_args: ListStr | None = None, - _serve_grpc: bool = False, - __test__: bool = False, - **_: t.Any, + model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", + fast: bool = False, adapter_map: dict[t.LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: ListStr | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any, ) -> LLMConfig | subprocess.Popen[bytes]: - """Python API to start a LLM server. These provides one-to-one mapping to CLI arguments. + """Python API to start a LLM server. These provides one-to-one mapping to CLI arguments. - For all additional arguments, pass it as string to ``additional_args``. For example, if you want to - pass ``--port 5001``, you can pass ``additional_args=["--port", "5001"]`` + For all additional arguments, pass it as string to ``additional_args``. For example, if you want to + pass ``--port 5001``, you can pass ``additional_args=["--port", "5001"]`` - > **Note**: This will create a blocking process, so if you use this API, you can create a running sub thread - > to start the server instead of blocking the main thread. + > **Note**: This will create a blocking process, so if you use this API, you can create a running sub thread + > to start the server instead of blocking the main thread. - ``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction. + ``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction. - > **Note**: ``quantize`` and ``bettertransformer`` are mutually exclusive. + > **Note**: ``quantize`` and ``bettertransformer`` are mutually exclusive. - Args: - model_name: The model name to start this LLM - model_id: Optional model id for this given LLM - timeout: The server timeout - workers_per_resource: Number of workers per resource assigned. - See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy - for more information. By default, this is set to 1. + Args: + model_name: The model name to start this LLM + model_id: Optional model id for this given LLM + timeout: The server timeout + workers_per_resource: Number of workers per resource assigned. + See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy + for more information. By default, this is set to 1. - > **Note**: ``--workers-per-resource`` will also accept the following strategies: + > **Note**: ``--workers-per-resource`` will also accept the following strategies: - > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. + > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. - > - ``conserved``: Thjis will determine the number of available GPU resources, and only assign - one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is - equivalent to ``--workers-per-resource 0.25``. - device: Assign GPU devices (if available) to this LLM. By default, this is set to ``None``. It also accepts 'all' - argument to assign all available GPUs to this LLM. - quantize: Quantize the model weights. This is only applicable for PyTorch models. - Possible quantisation strategies: - - int8: Quantize the model with 8bit (bitsandbytes required) - - int4: Quantize the model with 4bit (bitsandbytes required) - - gptq: Quantize the model with GPTQ (auto-gptq required) - bettertransformer: Convert given model to FastTransformer with PyTorch. - runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. - fast: Enable fast mode. This will skip downloading models, and will raise errors if given model_id does not exists under local store. - adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``. - framework: The framework to use for this LLM. By default, this is set to ``pt``. - additional_args: Additional arguments to pass to ``openllm start``. - """ - fast = os.getenv("OPENLLM_FAST", str(fast)).upper() in ENV_VARS_TRUE_VALUES - llm_config = AutoConfig.for_model(model_name) - _ModelEnv = EnvVarMixin(model_name, first_not_none(framework, default=llm_config.default_implementation()), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) - os.environ[_ModelEnv.framework] = _ModelEnv.framework_value + > - ``conserved``: Thjis will determine the number of available GPU resources, and only assign + one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is + equivalent to ``--workers-per-resource 0.25``. + device: Assign GPU devices (if available) to this LLM. By default, this is set to ``None``. It also accepts 'all' + argument to assign all available GPUs to this LLM. + quantize: Quantize the model weights. This is only applicable for PyTorch models. + Possible quantisation strategies: + - int8: Quantize the model with 8bit (bitsandbytes required) + - int4: Quantize the model with 4bit (bitsandbytes required) + - gptq: Quantize the model with GPTQ (auto-gptq required) + bettertransformer: Convert given model to FastTransformer with PyTorch. + runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. + fast: Enable fast mode. This will skip downloading models, and will raise errors if given model_id does not exists under local store. + adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``. + framework: The framework to use for this LLM. By default, this is set to ``pt``. + additional_args: Additional arguments to pass to ``openllm start``. + """ + fast = os.getenv("OPENLLM_FAST", str(fast)).upper() in ENV_VARS_TRUE_VALUES + llm_config = AutoConfig.for_model(model_name) + _ModelEnv = EnvVarMixin(model_name, first_not_none(framework, default=llm_config.default_implementation()), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) + os.environ[_ModelEnv.framework] = _ModelEnv.framework_value - args: ListStr = ["--runtime", runtime] - if model_id: args.extend(["--model-id", model_id]) - if timeout: args.extend(["--server-timeout", str(timeout)]) - if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource]) - if device and not os.getenv("CUDA_VISIBLE_DEVICES"): args.extend(["--device", ",".join(device)]) - if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") - if quantize: args.extend(["--quantize", str(quantize)]) - elif bettertransformer: args.append("--bettertransformer") - if fast: args.append("--fast") - if adapter_map: args.extend(list(itertools.chain.from_iterable([["--adapter-id", f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()]))) - if additional_args: args.extend(additional_args) - if __test__: args.append("--return-process") + args: ListStr = ["--runtime", runtime] + if model_id: args.extend(["--model-id", model_id]) + if timeout: args.extend(["--server-timeout", str(timeout)]) + if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource]) + if device and not os.getenv("CUDA_VISIBLE_DEVICES"): args.extend(["--device", ",".join(device)]) + if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") + if quantize: args.extend(["--quantize", str(quantize)]) + elif bettertransformer: args.append("--bettertransformer") + if fast: args.append("--fast") + if adapter_map: args.extend(list(itertools.chain.from_iterable([["--adapter-id", f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()]))) + if additional_args: args.extend(additional_args) + if __test__: args.append("--return-process") - return start_command_factory(start_command if not _serve_grpc else start_grpc_command, model_name, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=_serve_grpc).main( args=args if len(args) > 0 else None, standalone_mode=False) + return start_command_factory(start_command if not _serve_grpc else start_grpc_command, model_name, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=_serve_grpc).main(args=args if len(args) > 0 else None, standalone_mode=False) @inject def _build( - model_name: str, - /, - *, - model_id: str | None = None, - model_version: str | None = None, - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - bettertransformer: bool | None = None, - adapter_map: dict[str, str | None] | None = None, - build_ctx: str | None = None, - enable_features: tuple[str, ...] | None = None, - workers_per_resource: int | float | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers", - dockerfile_template: str | None = None, - overwrite: bool = False, - container_registry: LiteralContainerRegistry | None = None, - container_version_strategy: LiteralContainerVersionStrategy | None = None, - push: bool = False, - containerize: bool = False, - serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", - additional_args: list[str] | None = None, - bento_store: BentoStore = Provide[BentoMLContainer.bento_store], + model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, adapter_map: dict[str, str | None] | None = None, build_ctx: str | None = None, enable_features: tuple[str, ...] | None = None, workers_per_resource: int | float | None = None, runtime: t.Literal[ + "ggml", "transformers"] = "transformers", dockerfile_template: str | None = None, overwrite: bool = False, container_registry: LiteralContainerRegistry | None = None, container_version_strategy: LiteralContainerVersionStrategy | None = None, push: bool = False, containerize: bool = False, serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", + additional_args: list[str] | None = None, bento_store: BentoStore = Provide[BentoMLContainer.bento_store], ) -> bentoml.Bento: - """Package a LLM into a Bento. + """Package a LLM into a Bento. - The LLM will be built into a BentoService with the following structure: - if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time. - if ``bettertransformer`` is passed, it will instruct the model to apply FasterTransformer during serving time. + The LLM will be built into a BentoService with the following structure: + if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time. + if ``bettertransformer`` is passed, it will instruct the model to apply FasterTransformer during serving time. - ``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI. + ``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI. - > **Note**: ``quantize`` and ``bettertransformer`` are mutually exclusive. + > **Note**: ``quantize`` and ``bettertransformer`` are mutually exclusive. - Args: - model_name: The model name to start this LLM - model_id: Optional model id for this given LLM - model_version: Optional model version for this given LLM - quantize: Quantize the model weights. This is only applicable for PyTorch models. - Possible quantisation strategies: - - int8: Quantize the model with 8bit (bitsandbytes required) - - int4: Quantize the model with 4bit (bitsandbytes required) - - gptq: Quantize the model with GPTQ (auto-gptq required) - bettertransformer: Convert given model to FastTransformer with PyTorch. - adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``. - build_ctx: The build context to use for building BentoLLM. By default, it sets to current directory. - enable_features: Additional OpenLLM features to be included with this BentoLLM. - workers_per_resource: Number of workers per resource assigned. - See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy - for more information. By default, this is set to 1. + Args: + model_name: The model name to start this LLM + model_id: Optional model id for this given LLM + model_version: Optional model version for this given LLM + quantize: Quantize the model weights. This is only applicable for PyTorch models. + Possible quantisation strategies: + - int8: Quantize the model with 8bit (bitsandbytes required) + - int4: Quantize the model with 4bit (bitsandbytes required) + - gptq: Quantize the model with GPTQ (auto-gptq required) + bettertransformer: Convert given model to FastTransformer with PyTorch. + adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``. + build_ctx: The build context to use for building BentoLLM. By default, it sets to current directory. + enable_features: Additional OpenLLM features to be included with this BentoLLM. + workers_per_resource: Number of workers per resource assigned. + See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy + for more information. By default, this is set to 1. - > **Note**: ``--workers-per-resource`` will also accept the following strategies: + > **Note**: ``--workers-per-resource`` will also accept the following strategies: - > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. + > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. - > - ``conserved``: This will determine the number of available GPU resources, and only assign - one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is - equivalent to ``--workers-per-resource 0.25``. - runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. - dockerfile_template: The dockerfile template to use for building BentoLLM. See - https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template. - overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``. - push: Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first. - containerize: Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'. - Note that 'containerize' and 'push' are mutually exclusive - container_registry: Container registry to choose the base OpenLLM container image to build from. Default to ECR. - container_version_strategy: The container version strategy. Default to the latest release of OpenLLM. - serialisation_format: Serialisation for saving models. Default to 'safetensors', which is equivalent to `safe_serialization=True` - additional_args: Additional arguments to pass to ``openllm build``. - bento_store: Optional BentoStore for saving this BentoLLM. Default to the default BentoML local store. + > - ``conserved``: This will determine the number of available GPU resources, and only assign + one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is + equivalent to ``--workers-per-resource 0.25``. + runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. + dockerfile_template: The dockerfile template to use for building BentoLLM. See + https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template. + overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``. + push: Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first. + containerize: Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'. + Note that 'containerize' and 'push' are mutually exclusive + container_registry: Container registry to choose the base OpenLLM container image to build from. Default to ECR. + container_version_strategy: The container version strategy. Default to the latest release of OpenLLM. + serialisation_format: Serialisation for saving models. Default to 'safetensors', which is equivalent to `safe_serialization=True` + additional_args: Additional arguments to pass to ``openllm build``. + bento_store: Optional BentoStore for saving this BentoLLM. Default to the default BentoML local store. - Returns: - ``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud. - If 'format="container"', then it returns the default 'container_name:container_tag' - """ - args: ListStr = [ - sys.executable, - "-m", - "openllm", - "build", - model_name, - "--machine", - "--runtime", - runtime, - "--serialisation", - serialisation_format, - ] - if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") - if quantize: args.extend(["--quantize", quantize]) - if bettertransformer: args.append("--bettertransformer") - if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.") - if push: args.extend(["--push"]) - if containerize: args.extend(["--containerize"]) - if model_id: args.extend(["--model-id", model_id]) - if build_ctx: args.extend(["--build-ctx", build_ctx]) - if enable_features: args.extend([f"--enable-features={f}" for f in enable_features]) - if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource)]) - if overwrite: args.append("--overwrite") - if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) - if model_version: args.extend(["--model-version", model_version]) - if dockerfile_template: args.extend(["--dockerfile-template", dockerfile_template]) - if container_registry is None: container_registry = "ecr" - if container_version_strategy is None: container_version_strategy = "release" - args.extend(["--container-registry", container_registry, "--container-version-strategy", container_version_strategy]) - if additional_args: args.extend(additional_args) + Returns: + ``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud. + If 'format="container"', then it returns the default 'container_name:container_tag' + """ + args: ListStr = [sys.executable, "-m", "openllm", "build", model_name, "--machine", "--runtime", runtime, "--serialisation", serialisation_format,] + if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") + if quantize: args.extend(["--quantize", quantize]) + if bettertransformer: args.append("--bettertransformer") + if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.") + if push: args.extend(["--push"]) + if containerize: args.extend(["--containerize"]) + if model_id: args.extend(["--model-id", model_id]) + if build_ctx: args.extend(["--build-ctx", build_ctx]) + if enable_features: args.extend([f"--enable-features={f}" for f in enable_features]) + if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource)]) + if overwrite: args.append("--overwrite") + if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) + if model_version: args.extend(["--model-version", model_version]) + if dockerfile_template: args.extend(["--dockerfile-template", dockerfile_template]) + if container_registry is None: container_registry = "ecr" + if container_version_strategy is None: container_version_strategy = "release" + args.extend(["--container-registry", container_registry, "--container-version-strategy", container_version_strategy]) + if additional_args: args.extend(additional_args) - try: output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd()) - except subprocess.CalledProcessError as e: - logger.error("Exception caught while building %s", model_name, exc_info=e) - if e.stderr: raise OpenLLMException(e.stderr.decode("utf-8")) from None - raise OpenLLMException(str(e)) from None - pattern = r"^__tag__:[^:\n]+:[^:\n]+" - matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE) - if matched is None: raise ValueError(f"Failed to find tag from output: {output!s}") - return bentoml.get(matched.group(0).partition(":")[-1], _bento_store=bento_store) + try: + output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd()) + except subprocess.CalledProcessError as e: + logger.error("Exception caught while building %s", model_name, exc_info=e) + if e.stderr: raise OpenLLMException(e.stderr.decode("utf-8")) from None + raise OpenLLMException(str(e)) from None + pattern = r"^__tag__:[^:\n]+:[^:\n]+" + matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE) + if matched is None: raise ValueError(f"Failed to find tag from output: {output!s}") + return bentoml.get(matched.group(0).partition(":")[-1], _bento_store=bento_store) def _import_model( - model_name: str, - /, - *, - model_id: str | None = None, - model_version: str | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers", - implementation: LiteralRuntime = "pt", - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", - additional_args: t.Sequence[str] | None = None, + model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None, ) -> bentoml.Model: - """Import a LLM into local store. + """Import a LLM into local store. - > **Note**: If ``quantize`` is passed, the model weights will be saved as quantized weights. You should - > only use this option if you want the weight to be quantized by default. Note that OpenLLM also - > support on-demand quantisation during initial startup. + > **Note**: If ``quantize`` is passed, the model weights will be saved as quantized weights. You should + > only use this option if you want the weight to be quantized by default. Note that OpenLLM also + > support on-demand quantisation during initial startup. - ``openllm.download`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``. + ``openllm.download`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``. - > **Note**: ``openllm.start`` will automatically invoke ``openllm.download`` under the hood. + > **Note**: ``openllm.start`` will automatically invoke ``openllm.download`` under the hood. - Args: - model_name: The model name to start this LLM - model_id: Optional model id for this given LLM - model_version: Optional model version for this given LLM - runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. - implementation: The implementation to use for this LLM. By default, this is set to ``pt``. - quantize: Quantize the model weights. This is only applicable for PyTorch models. - Possible quantisation strategies: - - int8: Quantize the model with 8bit (bitsandbytes required) - - int4: Quantize the model with 4bit (bitsandbytes required) - - gptq: Quantize the model with GPTQ (auto-gptq required) - serialisation_format: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. - Default behaviour is similar to ``safe_serialization=False``. - additional_args: Additional arguments to pass to ``openllm import``. - - Returns: - ``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud. - """ - args = [ - model_name, - "--runtime", - runtime, - "--implementation", - implementation, - "--machine", - "--serialisation", - serialisation_format, - ] - if model_id is not None: args.append(model_id) - if model_version is not None: args.extend(["--model-version", str(model_version)]) - if additional_args is not None: args.extend(additional_args) - if quantize is not None: args.extend(["--quantize", quantize]) - return import_command.main(args=args, standalone_mode=False) + Args: + model_name: The model name to start this LLM + model_id: Optional model id for this given LLM + model_version: Optional model version for this given LLM + runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. + implementation: The implementation to use for this LLM. By default, this is set to ``pt``. + quantize: Quantize the model weights. This is only applicable for PyTorch models. + Possible quantisation strategies: + - int8: Quantize the model with 8bit (bitsandbytes required) + - int4: Quantize the model with 4bit (bitsandbytes required) + - gptq: Quantize the model with GPTQ (auto-gptq required) + serialisation_format: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. + Default behaviour is similar to ``safe_serialization=False``. + additional_args: Additional arguments to pass to ``openllm import``. + Returns: + ``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud. + """ + args = [model_name, "--runtime", runtime, "--implementation", implementation, "--machine", "--serialisation", serialisation_format,] + if model_id is not None: args.append(model_id) + if model_version is not None: args.extend(["--model-version", str(model_version)]) + if additional_args is not None: args.extend(additional_args) + if quantize is not None: args.extend(["--quantize", quantize]) + return import_command.main(args=args, standalone_mode=False) def _list_models() -> DictStrAny: - """List all available models within the local store.""" - return models_command.main(args=["-o", "json", "--show-available", "--machine"], standalone_mode=False) + """List all available models within the local store.""" + return models_command.main(args=["-o", "json", "--show-available", "--machine"], standalone_mode=False) start, start_grpc, build, import_model, list_models = codegen.gen_sdk(_start, _serve_grpc=False), codegen.gen_sdk(_start, _serve_grpc=True), codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models) @@ -714,10 +681,8 @@ start, start_grpc, build, import_model, list_models = codegen.gen_sdk(_start, _s @quantize_option(factory=cog.optgroup, build=True) @bettertransformer_option(factory=cog.optgroup) @click.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers.") -@click.option("--enable-features", multiple=True, nargs=1, metavar="FEATURE[,FEATURE]", - help="Enable additional features for building this LLM Bento. Available: {}".format(", ".join(OPTIONAL_DEPENDENCIES))) -@click.option("--adapter-id", default=None, multiple=True, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]", - help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.") +@click.option("--enable-features", multiple=True, nargs=1, metavar="FEATURE[,FEATURE]", help="Enable additional features for building this LLM Bento. Available: {}".format(", ".join(OPTIONAL_DEPENDENCIES))) +@click.option("--adapter-id", default=None, multiple=True, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]", help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.") @click.option("--build-ctx", help="Build context. This is required if --adapter-id uses relative path", default=None) @model_version_option @click.option("--dockerfile-template", default=None, type=click.File(), help="Optional custom dockerfile template to be used with this BentoLLM.") @@ -730,142 +695,110 @@ start, start_grpc, build, import_model, list_models = codegen.gen_sdk(_start, _s @cog.optgroup.option("--push", default=False, is_flag=True, type=click.BOOL, help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.") @click.pass_context def build_command( - ctx: click.Context, /, - model_name: str, - model_id: str | None, - overwrite: bool, - output: LiteralOutput, - runtime: t.Literal["ggml", "transformers"], - quantize: t.Literal["int8", "int4", "gptq"] | None, - enable_features: tuple[str, ...] | None, - bettertransformer: bool | None, - workers_per_resource: float | None, - adapter_id: tuple[str, ...], - build_ctx: str | None, - machine: bool, - model_version: str | None, - dockerfile_template: t.TextIO | None, - containerize: bool, - push: bool, - serialisation_format: t.Literal["safetensors", "legacy"], - fast: bool, - container_registry: LiteralContainerRegistry, - container_version_strategy: LiteralContainerVersionStrategy, - **attrs: t.Any, + ctx: click.Context, /, model_name: str, model_id: str | None, overwrite: bool, output: LiteralOutput, runtime: t.Literal["ggml", "transformers"], quantize: t.Literal["int8", "int4", "gptq"] | None, enable_features: tuple[str, ...] | None, bettertransformer: bool | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, machine: bool, + model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool, push: bool, serialisation_format: t.Literal["safetensors", "legacy"], fast: bool, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy, **attrs: t.Any, ) -> bentoml.Bento: - """Package a given models into a Bento. + """Package a given models into a Bento. - \b - ```bash - $ openllm build flan-t5 --model-id google/flan-t5-large - ``` + \b + ```bash + $ openllm build flan-t5 --model-id google/flan-t5-large + ``` - \b - > NOTE: To run a container built from this Bento with GPU support, make sure - > to have https://github.com/NVIDIA/nvidia-container-toolkit install locally. - """ - if machine: output = "porcelain" - if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features))) + \b + > NOTE: To run a container built from this Bento with GPU support, make sure + > to have https://github.com/NVIDIA/nvidia-container-toolkit install locally. + """ + if machine: output = "porcelain" + if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features))) - _previously_built = False + _previously_built = False - llm_config = AutoConfig.for_model(model_name) - env = EnvVarMixin(model_name, llm_config.default_implementation(), model_id=model_id, quantize=quantize, bettertransformer=bettertransformer, runtime=runtime) + llm_config = AutoConfig.for_model(model_name) + env = EnvVarMixin(model_name, llm_config.default_implementation(), model_id=model_id, quantize=quantize, bettertransformer=bettertransformer, runtime=runtime) - # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError - # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path + # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError + # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path + try: + os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env.runtime_value), "OPENLLM_SERIALIZATION": serialisation_format}) + if env.model_id_value: os.environ[env.model_id] = str(env.model_id_value) + if env.quantize_value: os.environ[env.quantize] = str(env.quantize_value) + if env.bettertransformer_value: os.environ[env.bettertransformer] = str(env.bettertransformer_value) + + llm = infer_auto_class(env.framework_value).for_model(model_name, llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs) + + labels = dict(llm.identifying_params) + labels.update({"_type": llm.llm_type, "_framework": env.framework_value}) + workers_per_resource = first_not_none(workers_per_resource, default=llm_config["workers_per_resource"]) + + with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs: + dockerfile_template_path = None + if dockerfile_template: + with dockerfile_template: + llm_fs.writetext("Dockerfile.template", dockerfile_template.read()) + dockerfile_template_path = llm_fs.getsyspath("/Dockerfile.template") + + adapter_map: dict[str, str | None] | None = None + if adapter_id: + if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passsed.") + adapter_map = {} + for v in adapter_id: + _adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) + name = adapter_name[0] if len(adapter_name) > 0 else None + try: + resolve_user_filepath(_adapter_id, build_ctx) + src_folder_name = os.path.basename(_adapter_id) + src_fs = fs.open_fs(build_ctx) + llm_fs.makedir(src_folder_name, recreate=True) + fs.copy.copy_dir(src_fs, _adapter_id, llm_fs, src_folder_name) + adapter_map[src_folder_name] = name + # this is the remote adapter, then just added back + # note that there is a drawback here. If the path of the local adapter + # path have the same name as the remote, then we currently don't support + # that edge case. + except FileNotFoundError: + adapter_map[_adapter_id] = name + os.environ["OPENLLM_ADAPTER_MAP"] = orjson.dumps(adapter_map).decode() + bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}".lower().strip()) + try: + bento = bentoml.get(bento_tag) + if overwrite: + if output == "pretty": termui.echo(f"Overwriting existing Bento {bento_tag}", fg="yellow") + bentoml.delete(bento_tag) + raise bentoml.exceptions.NotFound(f"Rebuilding existing Bento {bento_tag}") from None + _previously_built = True + except bentoml.exceptions.NotFound: + bento = bundle.create_bento(bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, adapter_map=adapter_map, quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime, container_registry=container_registry, container_version_strategy=container_version_strategy,) + except Exception as err: + raise err from None + + if machine: termui.echo(f"__tag__:{bento.tag}", fg="white") + elif output == "pretty": + if not get_quiet_mode() and (not push or not containerize): + termui.echo("\n" + OPENLLM_FIGLET, fg="white") + if not _previously_built: termui.echo(f"Successfully built {bento}.", fg="green") + elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg="yellow") + termui.echo( + "📖 Next steps:\n\n" + "* Push to BentoCloud with 'bentoml push':\n" + f" $ bentoml push {bento.tag}\n\n" + "* Containerize your Bento with 'bentoml containerize':\n" + f" $ bentoml containerize {bento.tag} --opt progress=plain" + "\n\n" + " Tip: To enable additional BentoML features for 'containerize', " + "use '--enable-features=FEATURE[,FEATURE]' " + + "[see 'bentoml containerize -h' for more advanced usage]\n", fg="blue", + ) + elif output == "json": + termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) + else: + termui.echo(bento.tag) + + if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context) + elif containerize: + backend = t.cast("DefaultBuilder", os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker")) try: - os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env.runtime_value), "OPENLLM_SERIALIZATION": serialisation_format}) - if env.model_id_value: os.environ[env.model_id] = str(env.model_id_value) - if env.quantize_value: os.environ[env.quantize] = str(env.quantize_value) - if env.bettertransformer_value: os.environ[env.bettertransformer] = str(env.bettertransformer_value) - - llm = infer_auto_class(env.framework_value).for_model(model_name, llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs) - - labels = dict(llm.identifying_params) - labels.update({"_type": llm.llm_type, "_framework": env.framework_value}) - workers_per_resource = first_not_none(workers_per_resource, default=llm_config["workers_per_resource"]) - - with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs: - dockerfile_template_path = None - if dockerfile_template: - with dockerfile_template: llm_fs.writetext("Dockerfile.template", dockerfile_template.read()) - dockerfile_template_path = llm_fs.getsyspath("/Dockerfile.template") - - adapter_map: dict[str, str | None] | None = None - if adapter_id: - if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passsed.") - adapter_map = {} - for v in adapter_id: - _adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) - name = adapter_name[0] if len(adapter_name) > 0 else None - try: - resolve_user_filepath(_adapter_id, build_ctx) - src_folder_name = os.path.basename(_adapter_id) - src_fs = fs.open_fs(build_ctx) - llm_fs.makedir(src_folder_name, recreate=True) - fs.copy.copy_dir(src_fs, _adapter_id, llm_fs, src_folder_name) - adapter_map[src_folder_name] = name - # this is the remote adapter, then just added back - # note that there is a drawback here. If the path of the local adapter - # path have the same name as the remote, then we currently don't support - # that edge case. - except FileNotFoundError: adapter_map[_adapter_id] = name - os.environ["OPENLLM_ADAPTER_MAP"] = orjson.dumps(adapter_map).decode() - bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}".lower().strip()) - try: - bento = bentoml.get(bento_tag) - if overwrite: - if output == "pretty": termui.echo(f"Overwriting existing Bento {bento_tag}", fg="yellow") - bentoml.delete(bento_tag) - raise bentoml.exceptions.NotFound(f"Rebuilding existing Bento {bento_tag}") from None - _previously_built = True - except bentoml.exceptions.NotFound: - bento = bundle.create_bento( - bento_tag, - llm_fs, - llm, - workers_per_resource=workers_per_resource, - adapter_map=adapter_map, - quantize=quantize, - bettertransformer=bettertransformer, - extra_dependencies=enable_features, - dockerfile_template=dockerfile_template_path, - runtime=runtime, - container_registry=container_registry, - container_version_strategy=container_version_strategy, - ) - except Exception as err: raise err from None - - if machine: termui.echo(f"__tag__:{bento.tag}", fg="white") - elif output == "pretty": - if not get_quiet_mode() and (not push or not containerize): - termui.echo("\n" + OPENLLM_FIGLET, fg="white") - if not _previously_built: termui.echo(f"Successfully built {bento}.", fg="green") - elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg="yellow") - termui.echo( - "📖 Next steps:\n\n" - + "* Push to BentoCloud with 'bentoml push':\n" - + f" $ bentoml push {bento.tag}\n\n" - + "* Containerize your Bento with 'bentoml containerize':\n" - + f" $ bentoml containerize {bento.tag} --opt progress=plain" - + "\n\n" - + " Tip: To enable additional BentoML features for 'containerize', " - + "use '--enable-features=FEATURE[,FEATURE]' " - + "[see 'bentoml containerize -h' for more advanced usage]\n", - fg="blue", - ) - elif output == "json": termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) - else: termui.echo(bento.tag) - - if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context) - elif containerize: - backend = t.cast("DefaultBuilder", os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker")) - try: bentoml.container.health(backend) - except subprocess.CalledProcessError: raise OpenLLMException(f"Failed to use backend {backend}") from None - try: bentoml.container.build(bento.tag, backend=backend, features=("grpc","io")) - except Exception as err: raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err - return bento + bentoml.container.health(backend) + except subprocess.CalledProcessError: + raise OpenLLMException(f"Failed to use backend {backend}") from None + try: + bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io")) + except Exception as err: + raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err + return bento @cli.command() @output_option @@ -873,118 +806,87 @@ def build_command( @machine_option @click.pass_context def models_command(ctx: click.Context, output: LiteralOutput, show_available: bool, machine: bool) -> DictStrAny | None: - """List all supported models. + """List all supported models. - \b - > NOTE: '--show-available' and '-o porcelain' are mutually exclusive. + \b + > NOTE: '--show-available' and '-o porcelain' are mutually exclusive. - \b - ```bash - openllm models --show-available - ``` - """ - from .._llm import normalise_model_name + \b + ```bash + openllm models --show-available + ``` + """ + from .._llm import normalise_model_name - models = tuple(inflection.dasherize(key) for key in CONFIG_MAPPING.keys()) - if output == "porcelain": - if show_available: raise click.BadOptionUsage("--show-available", "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).") - termui.echo("\n".join(models), fg="white") + models = tuple(inflection.dasherize(key) for key in CONFIG_MAPPING.keys()) + if output == "porcelain": + if show_available: raise click.BadOptionUsage("--show-available", "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).") + termui.echo("\n".join(models), fg="white") + else: + failed_initialized: list[tuple[str, Exception]] = [] + + json_data: dict[str, dict[t.Literal["architecture", "model_id", "url", "installation", "cpu", "gpu", "runtime_impl"], t.Any] | t.Any] = {} + converted: list[str] = [] + for m in models: + config = AutoConfig.for_model(m) + runtime_impl: tuple[str, ...] = () + if config["model_name"] in MODEL_MAPPING_NAMES: runtime_impl += ("pt",) + if config["model_name"] in MODEL_FLAX_MAPPING_NAMES: runtime_impl += ("flax",) + if config["model_name"] in MODEL_TF_MAPPING_NAMES: runtime_impl += ("tf",) + if config["model_name"] in MODEL_VLLM_MAPPING_NAMES: runtime_impl += ("vllm",) + json_data[m] = {"architecture": config["architecture"], "model_id": config["model_ids"], "cpu": not config["requires_gpu"], "gpu": True, "runtime_impl": runtime_impl, "installation": f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config["requirements"] else "openllm",} + converted.extend([normalise_model_name(i) for i in config["model_ids"]]) + if DEBUG: + try: + AutoLLM.for_model(m, llm_config=config) + except Exception as e: + failed_initialized.append((m, e)) + + ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in json_data.keys()} + ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} + local_models: DictStrAny | None = None + if show_available: + local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()} + + if machine: + if show_available: json_data["local"] = local_models + return json_data + elif output == "pretty": + import tabulate + + tabulate.PRESERVE_WHITESPACE = True + # llm, architecture, url, model_id, installation, cpu, gpu, runtime_impl + data: list[str | tuple[str, str, list[str], str, t.LiteralString, t.LiteralString, tuple[LiteralRuntime, ...]]] = [] + for m, v in json_data.items(): + data.extend([(m, v["architecture"], v["model_id"], v["installation"], "❌" if not v["cpu"] else "✅", "✅", v["runtime_impl"],)]) + column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 4),] + + if len(data) == 0 and len(failed_initialized) > 0: + termui.echo("Exception found while parsing models:\n", fg="yellow") + for m, err in failed_initialized: + termui.echo(f"- {m}: ", fg="yellow", nl=False) + termui.echo(traceback.print_exception(err, limit=3), fg="red") + sys.exit(1) + + table = tabulate.tabulate(data, tablefmt="fancy_grid", headers=["LLM", "Architecture", "Models Id", "pip install", "CPU", "GPU", "Runtime"], maxcolwidths=column_widths,) + termui.echo(table, fg="white") + + if DEBUG and len(failed_initialized) > 0: + termui.echo("\nThe following models are supported but failed to initialize:\n") + for m, err in failed_initialized: + termui.echo(f"- {m}: ", fg="blue", nl=False) + termui.echo(err, fg="red") + + if show_available: + if len(ids_in_local_store) == 0: + termui.echo("No models available locally.") + ctx.exit(0) + termui.echo("The following are available in local store:", fg="magenta") + termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white") else: - failed_initialized: list[tuple[str, Exception]] = [] - - json_data: dict[str, dict[t.Literal["architecture", "model_id", "url", "installation", "cpu", "gpu", "runtime_impl"], t.Any] | t.Any] = {} - converted: list[str] = [] - for m in models: - config = AutoConfig.for_model(m) - runtime_impl: tuple[str, ...] = () - if config["model_name"] in MODEL_MAPPING_NAMES: runtime_impl += ("pt",) - if config["model_name"] in MODEL_FLAX_MAPPING_NAMES: runtime_impl += ("flax",) - if config["model_name"] in MODEL_TF_MAPPING_NAMES: runtime_impl += ("tf",) - if config["model_name"] in MODEL_VLLM_MAPPING_NAMES: runtime_impl += ("vllm",) - json_data[m] = { - "architecture": config["architecture"], - "model_id": config["model_ids"], - "cpu": not config["requires_gpu"], - "gpu": True, - "runtime_impl": runtime_impl, - "installation": f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config["requirements"] else "openllm", - } - converted.extend([normalise_model_name(i) for i in config["model_ids"]]) - if DEBUG: - try: AutoLLM.for_model(m, llm_config=config) - except Exception as e: failed_initialized.append((m, e)) - - ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in json_data.keys()} - ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} - local_models: DictStrAny | None = None - if show_available: - local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()} - - if machine: - if show_available: json_data["local"] = local_models - return json_data - elif output == "pretty": - import tabulate - - tabulate.PRESERVE_WHITESPACE = True - # llm, architecture, url, model_id, installation, cpu, gpu, runtime_impl - data: list[str | tuple[str, str, list[str], str, t.LiteralString, t.LiteralString, tuple[LiteralRuntime, ...]] ] = [] - for m, v in json_data.items(): - data.extend( - [ - ( - m, - v["architecture"], - v["model_id"], - v["installation"], - "❌" if not v["cpu"] else "✅", - "✅", - v["runtime_impl"], - ) - ] - ) - column_widths = [ - int(termui.COLUMNS / 12), - int(termui.COLUMNS / 6), - int(termui.COLUMNS / 4), - int(termui.COLUMNS / 12), - int(termui.COLUMNS / 12), - int(termui.COLUMNS / 12), - int(termui.COLUMNS / 4), - ] - - if len(data) == 0 and len(failed_initialized) > 0: - termui.echo("Exception found while parsing models:\n", fg="yellow") - for m, err in failed_initialized: - termui.echo(f"- {m}: ", fg="yellow", nl=False) - termui.echo(traceback.print_exception(err, limit=3), fg="red") - sys.exit(1) - - table = tabulate.tabulate( - data, - tablefmt="fancy_grid", - headers=["LLM", "Architecture", "Models Id", "pip install", "CPU", "GPU", "Runtime"], - maxcolwidths=column_widths, - ) - termui.echo(table, fg="white") - - if DEBUG and len(failed_initialized) > 0: - termui.echo("\nThe following models are supported but failed to initialize:\n") - for m, err in failed_initialized: - termui.echo(f"- {m}: ", fg="blue", nl=False) - termui.echo(err, fg="red") - - if show_available: - if len(ids_in_local_store) == 0: - termui.echo("No models available locally.") - ctx.exit(0) - termui.echo("The following are available in local store:", fg="magenta") - termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white") - else: - if show_available: json_data["local"] = local_models - termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg="white") - ctx.exit(0) - + if show_available: json_data["local"] = local_models + termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg="white") + ctx.exit(0) @cli.command() @model_name_argument(required=False) @@ -992,92 +894,81 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo @click.option("--include-bentos/--no-include-bentos", is_flag=True, default=False, help="Whether to also include pruning bentos.") @inject def prune_command(model_name: str | None, yes: bool, include_bentos: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store], bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> None: - """Remove all saved models, (and optionally bentos) built with OpenLLM locally. + """Remove all saved models, (and optionally bentos) built with OpenLLM locally. - \b - If a model type is passed, then only prune models for that given model type. - """ - available: list[tuple[bentoml.Model | bentoml.Bento, ModelStore | BentoStore]]= [(m, model_store) for m in bentoml.models.list() if "framework" in m.info.labels and m.info.labels["framework"] == "openllm"] - if model_name is not None: available = [(m, store) for m, store in available if "model_name" in m.info.labels and m.info.labels["model_name"] == inflection.underscore(model_name)] - if include_bentos: - if model_name is not None: available += [(b, bento_store) for b in bentoml.bentos.list() if "start_name" in b.info.labels and b.info.labels["start_name"] == inflection.underscore(model_name)] - else: available += [(b, bento_store) for b in bentoml.bentos.list() if "_type" in b.info.labels and "_framework" in b.info.labels] + \b + If a model type is passed, then only prune models for that given model type. + """ + available: list[tuple[bentoml.Model | bentoml.Bento, ModelStore | BentoStore]] = [(m, model_store) for m in bentoml.models.list() if "framework" in m.info.labels and m.info.labels["framework"] == "openllm"] + if model_name is not None: available = [(m, store) for m, store in available if "model_name" in m.info.labels and m.info.labels["model_name"] == inflection.underscore(model_name)] + if include_bentos: + if model_name is not None: available += [(b, bento_store) for b in bentoml.bentos.list() if "start_name" in b.info.labels and b.info.labels["start_name"] == inflection.underscore(model_name)] + else: available += [(b, bento_store) for b in bentoml.bentos.list() if "_type" in b.info.labels and "_framework" in b.info.labels] - for store_item, store in available: - if yes: delete_confirmed = True - else: delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?") - if delete_confirmed: - store.delete(store_item.tag) - termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg="yellow") + for store_item, store in available: + if yes: delete_confirmed = True + else: delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?") + if delete_confirmed: + store.delete(store_item.tag) + termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg="yellow") +def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, value: list[str] | str | None) -> tuple[str, bool | str] | list[str] | str | None: + if value is None: + return value -def parsing_instruction_callback( - ctx: click.Context, param: click.Parameter, value: list[str] | str | None -) -> tuple[str, bool | str] | list[str] | str | None: - if value is None: - return value - - if isinstance(value, list): - # we only parse --text foo bar -> --text foo and omit bar - value = value[-1] - - key, *values = value.split("=") - if not key.startswith("--"): - raise click.BadParameter(f"Invalid option format: {value}") - key = key[2:] - if len(values) == 0: - return key, True - elif len(values) == 1: - return key, values[0] - else: - raise click.BadParameter(f"Invalid option format: {value}") + if isinstance(value, list): + # we only parse --text foo bar -> --text foo and omit bar + value = value[-1] + key, *values = value.split("=") + if not key.startswith("--"): + raise click.BadParameter(f"Invalid option format: {value}") + key = key[2:] + if len(values) == 0: + return key, True + elif len(values) == 1: + return key, values[0] + else: + raise click.BadParameter(f"Invalid option format: {value}") def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal["json", "porcelain", "pretty"] = "pretty") -> t.Callable[[FC], FC]: - options = [ - click.option( - "--endpoint", - type=click.STRING, - help="OpenLLM Server endpoint, i.e: http://localhost:3000", - envvar="OPENLLM_ENDPOINT", - default="http://localhost:3000", - ), - click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True), - output_option(default_value=output_value), - ] - return compose(*options)(f) if f is not None else compose(*options) - + options = [click.option("--endpoint", type=click.STRING, help="OpenLLM Server endpoint, i.e: http://localhost:3000", envvar="OPENLLM_ENDPOINT", default="http://localhost:3000",), click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True), output_option(default_value=output_value),] + return compose(*options)(f) if f is not None else compose(*options) @cli.command() @click.argument("task", type=click.STRING, metavar="TASK") @shared_client_options @click.option("--agent", type=click.Choice(["hf"]), default="hf", help="Whether to interact with Agents from given Server endpoint.", show_default=True) @click.option("--remote", is_flag=True, default=False, help="Whether or not to use remote tools (inference endpoints) instead of local ones.", show_default=True) -@click.option("--opt", help="Define prompt options. " "(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]") +@click.option("--opt", help="Define prompt options. " + "(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]") def instruct_command(endpoint: str, timeout: int, agent: t.LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str: - """Instruct agents interactively for given tasks, from a terminal. + """Instruct agents interactively for given tasks, from a terminal. - \b - ```bash - $ openllm instruct --endpoint http://12.323.2.1:3000 \\ + \b + ```bash + $ openllm instruct --endpoint http://12.323.2.1:3000 \\ "Is the following `text` (in Spanish) positive or negative?" \\ --text "¡Este es un API muy agradable!" - ``` - """ - client = openllm_client.HTTPClient(endpoint, timeout=timeout) + ``` + """ + client = openllm_client.HTTPClient(endpoint, timeout=timeout) - try: client.call("metadata") - except http.client.BadStatusLine: raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None - if agent == "hf": - if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'") - _memoized = {k: v[0] for k, v in _memoized.items() if v} - client._hf_agent.set_stream(logger.info) - if output != "porcelain": termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg="magenta") - result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized) - if output == "json": termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg="white") - else: termui.echo(result, fg="white") - return result - else: raise click.BadOptionUsage("agent", f"Unknown agent type {agent}") + try: + client.call("metadata") + except http.client.BadStatusLine: + raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None + if agent == "hf": + if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'") + _memoized = {k: v[0] for k, v in _memoized.items() if v} + client._hf_agent.set_stream(logger.info) + if output != "porcelain": termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg="magenta") + result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized) + if output == "json": termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg="white") + else: termui.echo(result, fg="white") + return result + else: + raise click.BadOptionUsage("agent", f"Unknown agent type {agent}") @cli.command() @shared_client_options(output_value="json") @@ -1086,25 +977,29 @@ def instruct_command(endpoint: str, timeout: int, agent: t.LiteralString, output @machine_option @click.pass_context def embed_command(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: LiteralOutput, machine: bool) -> EmbeddingsOutput | None: - """Get embeddings interactively, from a terminal. + """Get embeddings interactively, from a terminal. - \b - ```bash - $ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?" - ``` - """ - client = openllm_client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm_client.GrpcClient(endpoint, timeout=timeout) - try: gen_embed = client.embed(text) - except ValueError: raise click.ClickException(f"Endpoint {endpoint} does not support embeddings.") from None - if machine: return gen_embed - elif output == "pretty": - termui.echo("Generated embeddings: ", fg="magenta", nl=False) - termui.echo(gen_embed.embeddings, fg="white") - termui.echo("\nNumber of tokens: ", fg="magenta", nl=False) - termui.echo(gen_embed.num_tokens, fg="white") - elif output == "json": termui.echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg="white") - else: termui.echo(gen_embed.embeddings, fg="white") - ctx.exit(0) + \b + ```bash + $ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?" + ``` + """ + client = openllm_client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm_client.GrpcClient(endpoint, timeout=timeout) + try: + gen_embed = client.embed(text) + except ValueError: + raise click.ClickException(f"Endpoint {endpoint} does not support embeddings.") from None + if machine: return gen_embed + elif output == "pretty": + termui.echo("Generated embeddings: ", fg="magenta", nl=False) + termui.echo(gen_embed.embeddings, fg="white") + termui.echo("\nNumber of tokens: ", fg="magenta", nl=False) + termui.echo(gen_embed.num_tokens, fg="white") + elif output == "json": + termui.echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg="white") + else: + termui.echo(gen_embed.embeddings, fg="white") + ctx.exit(0) @cli.command() @shared_client_options @@ -1113,94 +1008,104 @@ def embed_command(ctx: click.Context, text: tuple[str, ...], endpoint: str, time @click.option("--sampling-params", help="Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]") @click.pass_context def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any) -> None: - """Ask a LLM interactively, from a terminal. + """Ask a LLM interactively, from a terminal. - \b - ```bash - $ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?" - ``` - """ - _memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v} - if server_type == "grpc": endpoint = re.sub(r"http://", "", endpoint) - client = openllm_client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm_client.GrpcClient(endpoint, timeout=timeout) - input_fg, generated_fg = "magenta", "cyan" - if output != "porcelain": - termui.echo("==Input==\n", fg="white") - termui.echo(f"{prompt}", fg=input_fg) - res = client.query(prompt, return_response="raw", **{**client.configuration, **_memoized}) - if output == "pretty": - response = client.llm.postprocess_generate(prompt, res["responses"]) - termui.echo("\n\n==Responses==\n", fg="white") - termui.echo(response, fg=generated_fg) - elif output == "json": termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white") - else: termui.echo(res["responses"], fg="white") - ctx.exit(0) + \b + ```bash + $ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?" + ``` + """ + _memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v} + if server_type == "grpc": endpoint = re.sub(r"http://", "", endpoint) + client = openllm_client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm_client.GrpcClient(endpoint, timeout=timeout) + input_fg, generated_fg = "magenta", "cyan" + if output != "porcelain": + termui.echo("==Input==\n", fg="white") + termui.echo(f"{prompt}", fg=input_fg) + res = client.query(prompt, return_response="raw", **{**client.configuration, **_memoized}) + if output == "pretty": + response = client.llm.postprocess_generate(prompt, res["responses"]) + termui.echo("\n\n==Responses==\n", fg="white") + termui.echo(response, fg=generated_fg) + elif output == "json": + termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white") + else: + termui.echo(res["responses"], fg="white") + ctx.exit(0) def load_notebook_metadata() -> DictStrAny: - with open(os.path.join(os.path.dirname(playground.__file__), "_meta.yml"), "r") as f: content = yaml.safe_load(f) - if not all("description" in k for k in content.values()): raise ValueError("Invalid metadata file. All entries must have a 'description' key.") - return content + with open(os.path.join(os.path.dirname(playground.__file__), "_meta.yml"), "r") as f: + content = yaml.safe_load(f) + if not all("description" in k for k in content.values()): raise ValueError("Invalid metadata file. All entries must have a 'description' key.") + return content @cli.command() @click.argument("output-dir", default=None, required=False) @click.option("--port", envvar="JUPYTER_PORT", show_envvar=True, show_default=True, default=8888, help="Default port for Jupyter server") @click.pass_context def playground_command(ctx: click.Context, output_dir: str | None, port: int) -> None: - """OpenLLM Playground. + """OpenLLM Playground. - A collections of notebooks to explore the capabilities of OpenLLM. - This includes notebooks for fine-tuning, inference, and more. + A collections of notebooks to explore the capabilities of OpenLLM. + This includes notebooks for fine-tuning, inference, and more. - All of the script available in the playground can also be run directly as a Python script: - For example: + All of the script available in the playground can also be run directly as a Python script: + For example: - \b - ```bash - python -m openllm.playground.falcon_tuned --help - ``` + \b + ```bash + python -m openllm.playground.falcon_tuned --help + ``` - \b - > Note: This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"' - """ - if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available(): - raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'") - metadata = load_notebook_metadata() - _temp_dir = False - if output_dir is None: - _temp_dir = True - output_dir = tempfile.mkdtemp(prefix="openllm-playground-") - else: os.makedirs(os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))), exist_ok=True) + \b + > Note: This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"' + """ + if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available(): + raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'") + metadata = load_notebook_metadata() + _temp_dir = False + if output_dir is None: + _temp_dir = True + output_dir = tempfile.mkdtemp(prefix="openllm-playground-") + else: + os.makedirs(os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))), exist_ok=True) - termui.echo("The playground notebooks will be saved to: " + os.path.abspath(output_dir), fg="blue") - for module in pkgutil.iter_modules(playground.__path__): - if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + ".ipynb")): - logger.debug("Skipping: %s (%s)", module.name, "File already exists" if not module.ispkg else f"{module.name} is a module") - continue - if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue - termui.echo("Generating notebook for: " + module.name, fg="magenta") - markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]["description"]) - f = jupytext.read(os.path.join(module.module_finder.path, module.name + ".py")) - f.cells.insert(0, markdown_cell) - jupytext.write(f, os.path.join(output_dir, module.name + ".ipynb"), fmt="notebook") - try: subprocess.check_output([sys.executable, "-m", "jupyter", "notebook", "--notebook-dir", output_dir, "--port", str(port), "--no-browser", "--debug"]) - except subprocess.CalledProcessError as e: - termui.echo(e.output, fg="red") - raise click.ClickException(f"Failed to start a jupyter server:\n{e}") from None - except KeyboardInterrupt: - termui.echo("\nShutting down Jupyter server...", fg="yellow") - if _temp_dir: termui.echo("Note: You can access the generated notebooks in: " + output_dir, fg="blue") - ctx.exit(0) + termui.echo("The playground notebooks will be saved to: " + os.path.abspath(output_dir), fg="blue") + for module in pkgutil.iter_modules(playground.__path__): + if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + ".ipynb")): + logger.debug("Skipping: %s (%s)", module.name, "File already exists" if not module.ispkg else f"{module.name} is a module") + continue + if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue + termui.echo("Generating notebook for: " + module.name, fg="magenta") + markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]["description"]) + f = jupytext.read(os.path.join(module.module_finder.path, module.name + ".py")) + f.cells.insert(0, markdown_cell) + jupytext.write(f, os.path.join(output_dir, module.name + ".ipynb"), fmt="notebook") + try: + subprocess.check_output([sys.executable, "-m", "jupyter", "notebook", "--notebook-dir", output_dir, "--port", str(port), "--no-browser", "--debug"]) + except subprocess.CalledProcessError as e: + termui.echo(e.output, fg="red") + raise click.ClickException(f"Failed to start a jupyter server:\n{e}") from None + except KeyboardInterrupt: + termui.echo("\nShutting down Jupyter server...", fg="yellow") + if _temp_dir: termui.echo("Note: You can access the generated notebooks in: " + output_dir, fg="blue") + ctx.exit(0) _EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), "ext")) class Extensions(click.MultiCommand): - def list_commands(self, ctx: click.Context) -> list[str]: return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith(".py") and not filename.startswith("__")]) - def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: - try: mod = __import__(f"openllm.cli.ext.{cmd_name}", None, None, ["cli"]) - except ImportError: return None - return mod.cli + def list_commands(self, ctx: click.Context) -> list[str]: + return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith(".py") and not filename.startswith("__")]) -@cli.group(cls=Extensions, name="ext", aliases=["utils"], help="Extension for OpenLLM CLI.") -def ext_command() -> None: ... + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: + try: + mod = __import__(f"openllm.cli.ext.{cmd_name}", None, None, ["cli"]) + except ImportError: + return None + return mod.cli + +@cli.group(cls=Extensions, name="ext", aliases=["utils"]) +def ext_command() -> None: + """Extension for OpenLLM CLI.""" if __name__ == "__main__": cli() diff --git a/src/openllm/cli/ext/__init__.py b/src/openllm/cli/ext/__init__.py index 695382f5..93977acf 100644 --- a/src/openllm/cli/ext/__init__.py +++ b/src/openllm/cli/ext/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """OpenLLM CLI Extension. The following directory contains all possible extensions for OpenLLM CLI diff --git a/src/openllm/cli/ext/build_base_container.py b/src/openllm/cli/ext/build_base_container.py index 1101459c..e77f9de0 100644 --- a/src/openllm/cli/ext/build_base_container.py +++ b/src/openllm/cli/ext/build_base_container.py @@ -24,12 +24,11 @@ from .. import termui from .._factory import machine_option if t.TYPE_CHECKING: - from openllm.bundle.oci import LiteralContainerRegistry - from openllm.bundle.oci import LiteralContainerVersionStrategy + from openllm.bundle.oci import LiteralContainerRegistry + from openllm.bundle.oci import LiteralContainerVersionStrategy -@click.command("build_base_container", - context_settings=termui.CONTEXT_SETTINGS, - help="""Base image builder for BentoLLM. +@click.command( + "build_base_container", context_settings=termui.CONTEXT_SETTINGS, help="""Base image builder for BentoLLM. By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04. @@ -40,12 +39,13 @@ if t.TYPE_CHECKING: This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels. Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself. - """) + """ +) @click.option("--registry", multiple=True, type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)), help="Target registry to create image tag on.", default=None) @click.option("--version-strategy", type=click.Choice(["release", "latest", "nightly"]), default="nightly", help="Version strategy to use for tagging the image.") @click.option("--push/--no-push", help="Whether to push to remote repository", is_flag=True, default=False) @machine_option def cli(registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]: - mapping = openllm.bundle.build_container(registry, version_strategy, push, machine) - if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white") - return mapping + mapping = openllm.bundle.build_container(registry, version_strategy, push, machine) + if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white") + return mapping diff --git a/src/openllm/cli/ext/dive_bentos.py b/src/openllm/cli/ext/dive_bentos.py index c3d05a5d..bfc37eee 100644 --- a/src/openllm/cli/ext/dive_bentos.py +++ b/src/openllm/cli/ext/dive_bentos.py @@ -28,7 +28,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer from .. import termui if t.TYPE_CHECKING: - from bentoml._internal.bento import BentoStore + from bentoml._internal.bento import BentoStore @click.command("dive_bentos", context_settings=termui.CONTEXT_SETTINGS) @click.argument("bento", type=str) @@ -36,12 +36,15 @@ if t.TYPE_CHECKING: @click.pass_context @inject def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None: - """Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).""" - try: bentomodel = _bento_store.get(bento) - except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.") - if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle": ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.") - if machine: return bentomodel.path - # copy and paste this into a new shell - if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path) - else:subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path) - ctx.exit(0) + """Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).""" + try: + bentomodel = _bento_store.get(bento) + except bentoml.exceptions.NotFound: + ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.") + if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle": + ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.") + if machine: return bentomodel.path + # copy and paste this into a new shell + if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path) + else: subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path) + ctx.exit(0) diff --git a/src/openllm/cli/ext/get_containerfile.py b/src/openllm/cli/ext/get_containerfile.py index c41354de..920f53a8 100644 --- a/src/openllm/cli/ext/get_containerfile.py +++ b/src/openllm/cli/ext/get_containerfile.py @@ -29,28 +29,28 @@ from .. import termui from ...utils import bentoml_cattr if t.TYPE_CHECKING: - from bentoml._internal.bento import BentoStore + from bentoml._internal.bento import BentoStore @click.command("get_containerfile", context_settings=termui.CONTEXT_SETTINGS, help="Return Containerfile of any given Bento.") @click.argument("bento", type=str) @click.pass_context @inject def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str: - try: bentomodel = _bento_store.get(bento) - except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.") - # The logic below are similar to bentoml._internal.container.construct_containerfile - with open(bentomodel.path_of("bento.yaml"), "r") as f: - options = BentoInfo.from_yaml_file(f) - # NOTE: dockerfile_template is already included in the - # Dockerfile inside bento, and it is not relevant to - # construct_containerfile. Hence it is safe to set it to None here. - # See https://github.com/bentoml/BentoML/issues/3399. - docker_attrs = bentoml_cattr.unstructure(options.docker) - # NOTE: if users specify a dockerfile_template, we will - # save it to /env/docker/Dockerfile.template. This is necessary - # for the reconstruction of the Dockerfile. - if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template" - termui.echo(generate_containerfile( - docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True, - ), fg="white") - return bentomodel.path + try: + bentomodel = _bento_store.get(bento) + except bentoml.exceptions.NotFound: + ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.") + # The logic below are similar to bentoml._internal.container.construct_containerfile + with open(bentomodel.path_of("bento.yaml"), "r") as f: + options = BentoInfo.from_yaml_file(f) + # NOTE: dockerfile_template is already included in the + # Dockerfile inside bento, and it is not relevant to + # construct_containerfile. Hence it is safe to set it to None here. + # See https://github.com/bentoml/BentoML/issues/3399. + docker_attrs = bentoml_cattr.unstructure(options.docker) + # NOTE: if users specify a dockerfile_template, we will + # save it to /env/docker/Dockerfile.template. This is necessary + # for the reconstruction of the Dockerfile. + if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template" + termui.echo(generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True,), fg="white") + return bentomodel.path diff --git a/src/openllm/cli/ext/get_prompt.py b/src/openllm/cli/ext/get_prompt.py index c82fb724..8ad89a17 100644 --- a/src/openllm/cli/ext/get_prompt.py +++ b/src/openllm/cli/ext/get_prompt.py @@ -26,7 +26,7 @@ from .. import termui from ..._prompt import process_prompt if t.TYPE_CHECKING: - from ..entrypoint import LiteralOutput + from ..entrypoint import LiteralOutput @click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS) @click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])) @@ -37,27 +37,29 @@ if t.TYPE_CHECKING: @click.option("--opt", help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]") @click.pass_context def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None: - """Get the default prompt used by OpenLLM.""" - module = openllm.utils.EnvVarMixin(model_name).module - _memoized = {k: v[0] for k, v in _memoized.items() if v} - try: - template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None) - prompt_mapping = getattr(module, "PROMPT_MAPPING", None) - if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None - if callable(template): - if format is None: - if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.") - raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})") - if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None - if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})") - _prompt_template = template(format) - else: _prompt_template = template - fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized) - if machine: return repr(fully_formatted) - elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white") - elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white") - else: - termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta") - termui.echo(fully_formatted, fg="white") - except AttributeError: raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None - ctx.exit(0) + """Get the default prompt used by OpenLLM.""" + module = openllm.utils.EnvVarMixin(model_name).module + _memoized = {k: v[0] for k, v in _memoized.items() if v} + try: + template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None) + prompt_mapping = getattr(module, "PROMPT_MAPPING", None) + if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None + if callable(template): + if format is None: + if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.") + raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})") + if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None + if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})") + _prompt_template = template(format) + else: + _prompt_template = template + fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized) + if machine: return repr(fully_formatted) + elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white") + elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white") + else: + termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta") + termui.echo(fully_formatted, fg="white") + except AttributeError: + raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None + ctx.exit(0) diff --git a/src/openllm/cli/ext/list_bentos.py b/src/openllm/cli/ext/list_bentos.py index fc1d58e6..4056c925 100644 --- a/src/openllm/cli/ext/list_bentos.py +++ b/src/openllm/cli/ext/list_bentos.py @@ -26,9 +26,9 @@ from .. import termui @click.command("list_bentos", context_settings=termui.CONTEXT_SETTINGS) @click.pass_context def cli(ctx: click.Context) -> None: - """List available bentos built by OpenLLM.""" - _local_bentos = {str(i.tag): i.info.labels["start_name"] for i in bentoml.list() if "start_name" in i.info.labels} - mapping = {k: [tag for tag, name in _local_bentos.items() if name == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())} - mapping = {k: v for k, v in mapping.items() if v} - termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white") - ctx.exit(0) + """List available bentos built by OpenLLM.""" + _local_bentos = {str(i.tag): i.info.labels["start_name"] for i in bentoml.list() if "start_name" in i.info.labels} + mapping = {k: [tag for tag, name in _local_bentos.items() if name == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())} + mapping = {k: v for k, v in mapping.items() if v} + termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white") + ctx.exit(0) diff --git a/src/openllm/cli/ext/list_models.py b/src/openllm/cli/ext/list_models.py index 796bd2ca..e7c28087 100644 --- a/src/openllm/cli/ext/list_models.py +++ b/src/openllm/cli/ext/list_models.py @@ -26,16 +26,16 @@ from .. import termui from .._factory import model_name_argument if t.TYPE_CHECKING: - from ..._types import DictStrAny + from ..._types import DictStrAny @click.command("list_models", context_settings=termui.CONTEXT_SETTINGS) @model_name_argument(required=False) def cli(model_name: str | None) -> DictStrAny: - """This is equivalent to openllm models --show-available less the nice table.""" - models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) - ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in models} - if model_name is not None: ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k,v in ids_in_local_store.items()} - ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} - local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()} - termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white") - return local_models + """This is equivalent to openllm models --show-available less the nice table.""" + models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) + ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in models} + if model_name is not None: ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()} + ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} + local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()} + termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white") + return local_models diff --git a/src/openllm/cli/termui.py b/src/openllm/cli/termui.py index 89e74cd5..9db513ed 100644 --- a/src/openllm/cli/termui.py +++ b/src/openllm/cli/termui.py @@ -23,13 +23,11 @@ from ..utils import get_debug_mode from ..utils import get_quiet_mode def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None: - attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho - if not get_quiet_mode(): call(text, **attrs) - + attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho + if not get_quiet_mode(): call(text, **attrs) COLUMNS = int(os.getenv("COLUMNS", str(120))) CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore} - __all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"] diff --git a/src/openllm/client.py b/src/openllm/client.py index edd4defc..24057c70 100644 --- a/src/openllm/client.py +++ b/src/openllm/client.py @@ -24,23 +24,25 @@ import importlib import itertools import typing as t -_import_structure: dict[str, list[str]] = { - "runtimes.grpc": ["AsyncGrpcClient", "GrpcClient"], - "runtimes.http": ["AsyncHTTPClient", "HTTPClient"], -} +_import_structure: dict[str, list[str]] = {"runtimes.grpc": ["AsyncGrpcClient", "GrpcClient"], "runtimes.http": ["AsyncHTTPClient", "HTTPClient"],} if t.TYPE_CHECKING: - from openllm_client import AsyncGrpcClient as AsyncGrpcClient - from openllm_client import AsyncHTTPClient as AsyncHTTPClient - from openllm_client import GrpcClient as GrpcClient - from openllm_client import HTTPClient as HTTPClient + from openllm_client import AsyncGrpcClient as AsyncGrpcClient + from openllm_client import AsyncHTTPClient as AsyncHTTPClient + from openllm_client import GrpcClient as GrpcClient + from openllm_client import HTTPClient as HTTPClient _module = "openllm_client" __all__ = list(itertools.chain.from_iterable(_import_structure.values())) -def __dir__() -> list[str]: return sorted(__all__) + +def __dir__() -> list[str]: + return sorted(__all__) + def __getattr__(name: str) -> t.Any: - if name in _import_structure: return importlib.import_module(f".{name}", _module) - try: module = next(module for module, attrs in _import_structure.items() if name in attrs) - except StopIteration: raise AttributeError(f"module {_module} has no attribute {name}") from None - return getattr(importlib.import_module(f".{module}", _module), name) + if name in _import_structure: return importlib.import_module(f".{name}", _module) + try: + module = next(module for module, attrs in _import_structure.items() if name in attrs) + except StopIteration: + raise AttributeError(f"module {_module} has no attribute {name}") from None + return getattr(importlib.import_module(f".{module}", _module), name) diff --git a/src/openllm/exceptions.py b/src/openllm/exceptions.py index 84bcdeff..b7409851 100644 --- a/src/openllm/exceptions.py +++ b/src/openllm/exceptions.py @@ -17,30 +17,25 @@ from __future__ import annotations import bentoml class OpenLLMException(bentoml.exceptions.BentoMLException): - """Base class for all OpenLLM exceptions. This extends BentoMLException.""" - + """Base class for all OpenLLM exceptions. This extends BentoMLException.""" class GpuNotAvailableError(OpenLLMException): - """Raised when there is no GPU available in given system.""" - + """Raised when there is no GPU available in given system.""" class ValidationError(OpenLLMException): - """Raised when a validation fails.""" - + """Raised when a validation fails.""" class ForbiddenAttributeError(OpenLLMException): - """Raised when using an _internal field.""" - + """Raised when using an _internal field.""" class MissingAnnotationAttributeError(OpenLLMException): - """Raised when a field under openllm.LLMConfig is missing annotations.""" - + """Raised when a field under openllm.LLMConfig is missing annotations.""" class MissingDependencyError(BaseException): - """Raised when a dependency is missing.""" + """Raised when a dependency is missing.""" class Error(BaseException): - """To be used instead of naked raise.""" + """To be used instead of naked raise.""" class FineTuneStrategyNotSupportedError(OpenLLMException): - """Raised when a fine-tune strategy is not supported for given LLM.""" + """Raised when a fine-tune strategy is not supported for given LLM.""" diff --git a/src/openllm/models/__init__.py b/src/openllm/models/__init__.py index 238aad88..64bee6d3 100644 --- a/src/openllm/models/__init__.py +++ b/src/openllm/models/__init__.py @@ -23,20 +23,21 @@ _MODELS: set[str] = {'auto', 'baichuan', 'chatglm', 'dolly_v2', 'falcon', 'flan_ # fmt: on if t.TYPE_CHECKING: - # fmt: off - # update-models-import.py: start types - from . import auto as auto - from . import baichuan as baichuan - from . import chatglm as chatglm - from . import dolly_v2 as dolly_v2 - from . import falcon as falcon - from . import flan_t5 as flan_t5 - from . import gpt_neox as gpt_neox - from . import llama as llama - from . import mpt as mpt - from . import opt as opt - from . import stablelm as stablelm - from . import starcoder as starcoder - # update-models-import.py: stop types - # fmt: on -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}, module_spec=__spec__) + # fmt: off + # update-models-import.py: start types + from . import auto as auto + from . import baichuan as baichuan + from . import chatglm as chatglm + from . import dolly_v2 as dolly_v2 + from . import falcon as falcon + from . import flan_t5 as flan_t5 + from . import gpt_neox as gpt_neox + from . import llama as llama + from . import mpt as mpt + from . import opt as opt + from . import stablelm as stablelm + from . import starcoder as starcoder + # update-models-import.py: stop types + # fmt: on +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}, module_spec=__spec__) diff --git a/src/openllm/models/auto/__init__.py b/src/openllm/models/auto/__init__.py index 95ee087e..ead2edbc 100644 --- a/src/openllm/models/auto/__init__.py +++ b/src/openllm/models/auto/__init__.py @@ -20,51 +20,63 @@ from ...utils import is_flax_available from ...utils import is_tf_available from ...utils import is_torch_available from ...utils import is_vllm_available -_import_structure: dict[str, list[str]] = { - "configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"], - "modeling_auto": ["MODEL_MAPPING_NAMES"], - "modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"], - "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], - "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"], -} + +_import_structure: dict[str, list[str]] = {"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"], "modeling_auto": ["MODEL_MAPPING_NAMES"], "modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"], "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"],} try: - if not is_torch_available(): raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: pass -else: _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"]) + if not is_torch_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: + pass +else: + _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"]) try: - if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: pass -else: _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"]) + if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: + pass +else: + _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"]) try: - if not is_flax_available(): raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: pass -else: _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"]) + if not is_flax_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: + pass +else: + _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"]) try: - if not is_tf_available(): raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: pass -else: _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) + if not is_tf_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: + pass +else: + _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) if t.TYPE_CHECKING: - from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING - from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES - from .configuration_auto import AutoConfig as AutoConfig - from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES - from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES - from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES - from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES - try: - if not is_torch_available(): raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: pass - else: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM - try: - if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: pass - else: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM - try: - if not is_flax_available(): raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: pass - else: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM - try: - if not is_tf_available(): raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: pass - else: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING + from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES + from .configuration_auto import AutoConfig as AutoConfig + from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES + from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES + from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES + from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES + try: + if not is_torch_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: + pass + else: + from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM + try: + if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: + pass + else: + from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM + try: + if not is_flax_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: + pass + else: + from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM + try: + if not is_tf_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: + pass + else: + from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/auto/configuration_auto.py b/src/openllm/models/auto/configuration_auto.py index 691890fc..c48fe0b8 100644 --- a/src/openllm/models/auto/configuration_auto.py +++ b/src/openllm/models/auto/configuration_auto.py @@ -18,77 +18,82 @@ import inflection import openllm from ...utils import ReprMixin if t.TYPE_CHECKING: - import types - from collections import _odict_items, _odict_keys, _odict_values - ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]] - ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]] - ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]] - ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]] + import types + from collections import _odict_items, _odict_keys, _odict_values + ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]] + ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]] + ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]] + ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]] else: - ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any - ConfigOrderedDict = OrderedDict + ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any + ConfigOrderedDict = OrderedDict # NOTE: This is the entrypoint when adding new model config -CONFIG_MAPPING_NAMES = OrderedDict( - [ - ("chatglm", "ChatGLMConfig"), - ("dolly_v2", "DollyV2Config"), - ("falcon", "FalconConfig"), - ("flan_t5", "FlanT5Config"), - ("gpt_neox", "GPTNeoXConfig"), - ("llama", "LlamaConfig"), - ("mpt", "MPTConfig"), - ("opt", "OPTConfig"), - ("stablelm", "StableLMConfig"), - ("starcoder", "StarCoderConfig"), - ("baichuan", "BaichuanConfig"), - ] -) +CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig"),]) + class _LazyConfigMapping(ConfigOrderedDict, ReprMixin): - def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]): - self._mapping = mapping - self._extra_content: dict[str, t.Any] = {} - self._modules: dict[str, types.ModuleType] = {} - def __getitem__(self, key: str) -> t.Any: - if key in self._extra_content: return self._extra_content[key] - if key not in self._mapping: - if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key)) - raise KeyError(key) - value, module_name = self._mapping[key], inflection.underscore(key) - if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module - if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value) - # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level. - return getattr(openllm, value) - @property - def __repr_keys__(self) -> set[str]: return set(self._mapping.keys()) - def __repr__(self) -> str: return ReprMixin.__repr__(self) - def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: yield from self._mapping.items() - def keys(self): return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys())) - def values(self): return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())) - def items(self): return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())) - def __iter__(self): return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) - def __contains__(self, item: t.Any): return item in self._mapping or item in self._extra_content - def register(self, key: str, value: t.Any): - if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.") - self._extra_content[key] = value + def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]): + self._mapping = mapping + self._extra_content: dict[str, t.Any] = {} + self._modules: dict[str, types.ModuleType] = {} + + def __getitem__(self, key: str) -> t.Any: + if key in self._extra_content: return self._extra_content[key] + if key not in self._mapping: + if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key)) + raise KeyError(key) + value, module_name = self._mapping[key], inflection.underscore(key) + if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module + if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value) + # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level. + return getattr(openllm, value) + + @property + def __repr_keys__(self) -> set[str]: + return set(self._mapping.keys()) + + def __repr__(self) -> str: + return ReprMixin.__repr__(self) + + def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: + yield from self._mapping.items() + + def keys(self): + return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys())) + + def values(self): + return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())) + + def items(self): + return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())) + + def __iter__(self): + return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) + + def __contains__(self, item: t.Any): + return item in self._mapping or item in self._extra_content + + def register(self, key: str, value: t.Any): + if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.") + self._extra_content[key] = value + CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES) # The below handle special alias when we call underscore to the name directly # without processing camelcase first. -CONFIG_NAME_ALIASES: dict[str, str] = { - "chat_glm": "chatglm", - "stable_lm": "stablelm", - "star_coder": "starcoder", - "gpt_neo_x": "gpt_neox", -} +CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder", "gpt_neo_x": "gpt_neox",} + class AutoConfig: - def __init__(self, *_: t.Any, **__: t.Any): raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.") - @classmethod - def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig: - model_name = inflection.underscore(model_name) - if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs) - raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") - @classmethod - def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]: - model_name = inflection.underscore(name) - if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name] - if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name] - raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") + def __init__(self, *_: t.Any, **__: t.Any): + raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.") + + @classmethod + def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig: + model_name = inflection.underscore(model_name) + if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs) + raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") + + @classmethod + def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]: + model_name = inflection.underscore(name) + if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name] + if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name] + raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index c0184ce4..9603aa54 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -21,35 +21,40 @@ import inflection import openllm from ...utils import ReprMixin if t.TYPE_CHECKING: - import types - from ..._llm import LLMRunner - from collections import _odict_items, _odict_keys, _odict_values - ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] - ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] - ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] - ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] + import types + from ..._llm import LLMRunner + from collections import _odict_items, _odict_keys, _odict_values + ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] + ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] + ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] + ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] else: - ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any - ConfigModelOrderedDict = OrderedDict + ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any + ConfigModelOrderedDict = OrderedDict logger = logging.getLogger(__name__) + class BaseAutoLLMClass: - _model_mapping: _LazyAutoMapping - def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.") - @classmethod - def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]: - """The lower level API for creating a LLM instance. + _model_mapping: _LazyAutoMapping + + def __init__(self, *args: t.Any, **attrs: t.Any): + raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.") + + @classmethod + def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]: + """The lower level API for creating a LLM instance. ```python >>> import openllm >>> llm = openllm.AutoLLM.for_model("flan-t5") ``` """ - llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs) - if ensure_available: llm.ensure_model_id_exists() - return llm - @classmethod - def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: - """Create a LLM Runner for the given model name. + llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs) + if ensure_available: llm.ensure_model_id_exists() + return llm + + @classmethod + def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: + """Create a LLM Runner for the given model name. Args: model: The model name to instantiate. @@ -59,76 +64,107 @@ class BaseAutoLLMClass: Returns: A LLM instance. """ - runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters) - runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name} - for k in runner_attrs: del attrs[k] - return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs) - @classmethod - def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]): - """Register a new model for this class. + runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters) + runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name} + for k in runner_attrs: + del attrs[k] + return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs) + + @classmethod + def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]): + """Register a new model for this class. Args: config_class: The configuration corresponding to the model to register. llm_class: The runnable to register. """ - if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!") - cls._model_mapping.register(config_class, llm_class) - @classmethod - def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]: - config_class = openllm.AutoConfig.infer_class_from_name(name) - if config_class in cls._model_mapping: return cls._model_mapping[config_class] - raise ValueError(f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])}).") + if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: + raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!") + cls._model_mapping.register(config_class, llm_class) + + @classmethod + def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]: + config_class = openllm.AutoConfig.infer_class_from_name(name) + if config_class in cls._model_mapping: return cls._model_mapping[config_class] + raise ValueError(f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])}).") + def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any: - if attr is None: return - if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr) - if hasattr(module, attr): return getattr(module, attr) - # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level. - openllm_module = importlib.import_module("openllm") - if module != openllm_module: - try: return getattribute_from_module(openllm_module, attr) - except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None - raise ValueError(f"Could not find {attr} in {openllm_module}!") + if attr is None: return + if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr) + if hasattr(module, attr): return getattr(module, attr) + # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level. + openllm_module = importlib.import_module("openllm") + if module != openllm_module: + try: + return getattribute_from_module(openllm_module, attr) + except ValueError: + raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None + raise ValueError(f"Could not find {attr} in {openllm_module}!") + class _LazyAutoMapping(ConfigModelOrderedDict, ReprMixin): - """Based on transformers.models.auto.configuration_auto._LazyAutoMapping. + """Based on transformers.models.auto.configuration_auto._LazyAutoMapping. This OrderedDict values() and keys() returns the list instead, so you don't have to do list(mapping.values()) to get the list of values. """ - def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]): - self._config_mapping = config_mapping - self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} - self._model_mapping = model_mapping - self._extra_content: dict[t.Any, t.Any] = {} - self._modules: dict[str, types.ModuleType] = {} - def __len__(self): return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content) - def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]: - if key in self._extra_content: return self._extra_content[key] - model_type = self._reverse_config_mapping[key.__name__] - if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type]) - # Maybe there was several model types associated with this config. - model_types = [k for k, v in self._config_mapping.items() if v == key.__name__] - for mtype in model_types: - if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype]) - raise KeyError(key) - def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any: - module_name = inflection.underscore(model_type) - if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models") - return getattribute_from_module(self._modules[module_name], attr) - def keys(self): return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys())) - @property - def __repr_keys__(self) -> set[str]: return set(self._config_mapping.keys()) - def __repr__(self) -> str: return ReprMixin.__repr__(self) - def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping) - def __bool__(self): return bool(self.keys()) - def values(self): return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values())) - def items(self): return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items())) - def __iter__(self): return iter(self.keys()) - def __contains__(self, item: t.Any): - if item in self._extra_content: return True - if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False - return self._reverse_config_mapping[item.__name__] in self._model_mapping - def register(self, key: t.Any, value: t.Any): - if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: - if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.") - self._extra_content[key] = value + def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]): + self._config_mapping = config_mapping + self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} + self._model_mapping = model_mapping + self._extra_content: dict[t.Any, t.Any] = {} + self._modules: dict[str, types.ModuleType] = {} + + def __len__(self): + return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content) + + def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]: + if key in self._extra_content: return self._extra_content[key] + model_type = self._reverse_config_mapping[key.__name__] + if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type]) + # Maybe there was several model types associated with this config. + model_types = [k for k, v in self._config_mapping.items() if v == key.__name__] + for mtype in model_types: + if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype]) + raise KeyError(key) + + def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any: + module_name = inflection.underscore(model_type) + if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models") + return getattribute_from_module(self._modules[module_name], attr) + + def keys(self): + return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys())) + + @property + def __repr_keys__(self) -> set[str]: + return set(self._config_mapping.keys()) + + def __repr__(self) -> str: + return ReprMixin.__repr__(self) + + def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: + yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping) + + def __bool__(self): + return bool(self.keys()) + + def values(self): + return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values())) + + def items(self): + return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items())) + + def __iter__(self): + return iter(self.keys()) + + def __contains__(self, item: t.Any): + if item in self._extra_content: return True + if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False + return self._reverse_config_mapping[item.__name__] in self._model_mapping + + def register(self, key: t.Any, value: t.Any): + if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: + if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.") + self._extra_content[key] = value + __all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"] diff --git a/src/openllm/models/auto/modeling_auto.py b/src/openllm/models/auto/modeling_auto.py index aa5e08a3..51930569 100644 --- a/src/openllm/models/auto/modeling_auto.py +++ b/src/openllm/models/auto/modeling_auto.py @@ -17,21 +17,9 @@ from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping -MODEL_MAPPING_NAMES = OrderedDict( - [ - ("chatglm", "ChatGLM"), - ("dolly_v2", "DollyV2"), - ("falcon", "Falcon"), - ("flan_t5", "FlanT5"), - ("gpt_neox", "GPTNeoX"), - ("llama", "Llama"), - ("mpt", "MPT"), - ("opt", "OPT"), - ("stablelm", "StableLM"), - ("starcoder", "StarCoder"), - ("baichuan", "Baichuan"), - ] -) + +MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan"),]) MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) + class AutoLLM(BaseAutoLLMClass): - _model_mapping = MODEL_MAPPING + _model_mapping = MODEL_MAPPING diff --git a/src/openllm/models/auto/modeling_flax_auto.py b/src/openllm/models/auto/modeling_flax_auto.py index 18f489cf..6be4646e 100644 --- a/src/openllm/models/auto/modeling_flax_auto.py +++ b/src/openllm/models/auto/modeling_flax_auto.py @@ -16,12 +16,9 @@ from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping -MODEL_FLAX_MAPPING_NAMES = OrderedDict( - [ - ("flan_t5", "FlaxFlanT5"), - ("opt", "FlaxOPT"), - ] -) + +MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT"),]) MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES) + class AutoFlaxLLM(BaseAutoLLMClass): - _model_mapping = MODEL_FLAX_MAPPING + _model_mapping = MODEL_FLAX_MAPPING diff --git a/src/openllm/models/auto/modeling_tf_auto.py b/src/openllm/models/auto/modeling_tf_auto.py index fd90f56b..2439b59b 100644 --- a/src/openllm/models/auto/modeling_tf_auto.py +++ b/src/openllm/models/auto/modeling_tf_auto.py @@ -16,12 +16,9 @@ from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping -MODEL_TF_MAPPING_NAMES = OrderedDict( - [ - ("flan_t5", "TFFlanT5"), - ("opt", "TFOPT"), - ] -) + +MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT"),]) MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES) + class AutoTFLLM(BaseAutoLLMClass): - _model_mapping = MODEL_TF_MAPPING + _model_mapping = MODEL_TF_MAPPING diff --git a/src/openllm/models/auto/modeling_vllm_auto.py b/src/openllm/models/auto/modeling_vllm_auto.py index ccabd0df..751836a8 100644 --- a/src/openllm/models/auto/modeling_vllm_auto.py +++ b/src/openllm/models/auto/modeling_vllm_auto.py @@ -16,12 +16,9 @@ from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping -MODEL_VLLM_MAPPING_NAMES = OrderedDict( - [ - ("llama", "VLLMLlama"), - ("opt", "VLLMOPT") - ] -) + +MODEL_VLLM_MAPPING_NAMES = OrderedDict([("llama", "VLLMLlama"), ("opt", "VLLMOPT")]) MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES) + class AutoVLLM(BaseAutoLLMClass): - _model_mapping = MODEL_VLLM_MAPPING + _model_mapping = MODEL_VLLM_MAPPING diff --git a/src/openllm/models/baichuan/__init__.py b/src/openllm/models/baichuan/__init__.py index 3a210e92..fc4e60f1 100644 --- a/src/openllm/models/baichuan/__init__.py +++ b/src/openllm/models/baichuan/__init__.py @@ -18,18 +18,24 @@ from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_cpm_kernels_available from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_baichuan": ["BaichuanConfig", "START_BAICHUAN_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_baichuan"] = ["Baichuan"] + if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_baichuan"] = ["Baichuan"] if t.TYPE_CHECKING: - from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING - from .configuration_baichuan import BaichuanConfig as BaichuanConfig + from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING + from .configuration_baichuan import BaichuanConfig as BaichuanConfig - try: - if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_baichuan import Baichuan as Baichuan -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + try: + if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_baichuan import Baichuan as Baichuan +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/baichuan/configuration_baichuan.py b/src/openllm/models/baichuan/configuration_baichuan.py index a1fe5c95..2ed989a1 100644 --- a/src/openllm/models/baichuan/configuration_baichuan.py +++ b/src/openllm/models/baichuan/configuration_baichuan.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations import openllm + class BaichuanConfig(openllm.LLMConfig): - """Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology. + """Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology. Baichuan-7B is based on Transformer architecture, which contains 7 billion parameters and trained on approximately 1.2 trillion tokens. @@ -23,28 +24,16 @@ class BaichuanConfig(openllm.LLMConfig): and English benchmarks (C-Eval, MMLU, etc). Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information. """ - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "timeout": 3600000, - "requires_gpu": True, - "url": "https://github.com/baichuan-inc/Baichuan-7B", - "requirements": ["cpm-kernels", "sentencepiece"], - "architecture": "BaiChuanForCausalLM", - "default_id": "baichuan-inc/baichuan-7b", - "model_ids": [ - "baichuan-inc/baichuan-7b", - "baichuan-inc/baichuan-13b-base", - "baichuan-inc/baichuan-13b-chat", - "fireballoon/baichuan-vicuna-chinese-7b", - "fireballoon/baichuan-vicuna-7b", - "hiyouga/baichuan-7b-sft", - ], - } - class GenerationConfig: - max_new_tokens: int = 2048 - top_p: float = 0.7 - temperature: float = 0.95 + __config__ = { + "name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/baichuan-inc/Baichuan-7B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "BaiChuanForCausalLM", "default_id": "baichuan-inc/baichuan-7b", + "model_ids": ["baichuan-inc/baichuan-7b", "baichuan-inc/baichuan-13b-base", "baichuan-inc/baichuan-13b-chat", "fireballoon/baichuan-vicuna-chinese-7b", "fireballoon/baichuan-vicuna-7b", "hiyouga/baichuan-7b-sft",], + } + + class GenerationConfig: + max_new_tokens: int = 2048 + top_p: float = 0.7 + temperature: float = 0.95 + START_BAICHUAN_COMMAND_DOCSTRING = """\ Run a LLMServer for Baichuan model. diff --git a/src/openllm/models/baichuan/modeling_baichuan.py b/src/openllm/models/baichuan/modeling_baichuan.py index 3942b867..dedeeb37 100644 --- a/src/openllm/models/baichuan/modeling_baichuan.py +++ b/src/openllm/models/baichuan/modeling_baichuan.py @@ -18,12 +18,18 @@ from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") + class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): - outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) - return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): + outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) + return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) diff --git a/src/openllm/models/chatglm/__init__.py b/src/openllm/models/chatglm/__init__.py index 00457a4c..73d0d59c 100644 --- a/src/openllm/models/chatglm/__init__.py +++ b/src/openllm/models/chatglm/__init__.py @@ -18,17 +18,23 @@ from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_cpm_kernels_available from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_chatglm": ["ChatGLMConfig", "START_CHATGLM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_chatglm"] = ["ChatGLM"] + if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_chatglm"] = ["ChatGLM"] if t.TYPE_CHECKING: - from .configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING - from .configuration_chatglm import ChatGLMConfig as ChatGLMConfig - try: - if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_chatglm import ChatGLM as ChatGLM -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING + from .configuration_chatglm import ChatGLMConfig as ChatGLMConfig + try: + if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_chatglm import ChatGLM as ChatGLM +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/chatglm/configuration_chatglm.py b/src/openllm/models/chatglm/configuration_chatglm.py index 711e7d1b..4f5777b0 100644 --- a/src/openllm/models/chatglm/configuration_chatglm.py +++ b/src/openllm/models/chatglm/configuration_chatglm.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations import openllm + class ChatGLMConfig(openllm.LLMConfig): - """ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework. + """ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). @@ -27,34 +28,20 @@ class ChatGLMConfig(openllm.LLMConfig): Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information. """ - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "timeout": 3600000, - "requires_gpu": True, - "url": "https://github.com/THUDM/ChatGLM-6B", - "requirements": ["cpm-kernels", "sentencepiece"], - "architecture": "ChatGLMForConditionalGeneration", - "default_id": "thudm/chatglm-6b", - "model_ids": [ - "thudm/chatglm-6b", - "thudm/chatglm-6b-int8", - "thudm/chatglm-6b-int4", - "thudm/chatglm2-6b", - "thudm/chatglm2-6b-int4", - ], - } - retain_history: bool = openllm.LLMConfig.Field( - False, - description="""Whether to retain history given to the model. - If set to True, then the model will retain given history.""", - ) - use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.") - class GenerationConfig: - max_new_tokens: int = 2048 - num_beams: int = 1 - top_p: float = 0.7 - temperature: float = 0.95 + __config__ = { + "name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/THUDM/ChatGLM-6B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "ChatGLMForConditionalGeneration", "default_id": "thudm/chatglm-6b", + "model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4", "thudm/chatglm2-6b", "thudm/chatglm2-6b-int4",], + } + retain_history: bool = openllm.LLMConfig.Field(False, description="""Whether to retain history given to the model. + If set to True, then the model will retain given history.""",) + use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.") + + class GenerationConfig: + max_new_tokens: int = 2048 + num_beams: int = 1 + top_p: float = 0.7 + temperature: float = 0.95 + START_CHATGLM_COMMAND_DOCSTRING = """\ Run a LLMServer for ChatGLM model. diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py index 91bac01a..ed422852 100644 --- a/src/openllm/models/chatglm/modeling_chatglm.py +++ b/src/openllm/models/chatglm/modeling_chatglm.py @@ -17,38 +17,45 @@ import openllm from ..._llm import LLMEmbeddings if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional") + class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - prompt_text = "" - if use_default_prompt_template and chat_history is not None: - for i, (old_query, response) in enumerate(chat_history): prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001 - prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001 - else: prompt_text = prompt - postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None} - # NOTE: The rest of attrs should be kwargs for GenerationConfig - generate_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs} - return prompt_text, generate_kwargs, postprocess_generate_kwargs - def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any): - generated, history = generation_result - if self.config.retain_history: - assert chat_history is not None, "'retain_history' is True while there is no history provided." - chat_history.extend(history) - return generated - def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]: - with torch.inference_mode(): - self.model.eval() - # Only use half precision if the model is not yet quantized - if self.config.use_half_precision: self.model.half() - return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) - def embeddings(self, prompts: list[str]) -> LLMEmbeddings: - embeddings: list[list[float]] = [] - num_tokens = 0 - for prompt in prompts: - input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device) - with torch.inference_mode(): - outputs = self.model(input_ids, output_hidden_states=True) - data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0) - embeddings.append(data.tolist()) - num_tokens += len(input_ids[0]) - return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + prompt_text = "" + if use_default_prompt_template and chat_history is not None: + for i, (old_query, response) in enumerate(chat_history): + prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001 + prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001 + else: + prompt_text = prompt + postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None} + # NOTE: The rest of attrs should be kwargs for GenerationConfig + generate_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs} + return prompt_text, generate_kwargs, postprocess_generate_kwargs + + def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any): + generated, history = generation_result + if self.config.retain_history: + assert chat_history is not None, "'retain_history' is True while there is no history provided." + chat_history.extend(history) + return generated + + def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]: + with torch.inference_mode(): + self.model.eval() + # Only use half precision if the model is not yet quantized + if self.config.use_half_precision: self.model.half() + return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) + + def embeddings(self, prompts: list[str]) -> LLMEmbeddings: + embeddings: list[list[float]] = [] + num_tokens = 0 + for prompt in prompts: + input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device) + with torch.inference_mode(): + outputs = self.model(input_ids, output_hidden_states=True) + data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0) + embeddings.append(data.tolist()) + num_tokens += len(input_ids[0]) + return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens) diff --git a/src/openllm/models/dolly_v2/__init__.py b/src/openllm/models/dolly_v2/__init__.py index 7f5a851c..e5d00318 100644 --- a/src/openllm/models/dolly_v2/__init__.py +++ b/src/openllm/models/dolly_v2/__init__.py @@ -17,17 +17,23 @@ import typing as t from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_dolly_v2"] = ["DollyV2"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_dolly_v2"] = ["DollyV2"] if t.TYPE_CHECKING: - from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING - from .configuration_dolly_v2 import DollyV2Config as DollyV2Config - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_dolly_v2 import DollyV2 as DollyV2 -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING + from .configuration_dolly_v2 import DollyV2Config as DollyV2Config + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_dolly_v2 import DollyV2 as DollyV2 +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/dolly_v2/configuration_dolly_v2.py b/src/openllm/models/dolly_v2/configuration_dolly_v2.py index e0eba513..20131f82 100644 --- a/src/openllm/models/dolly_v2/configuration_dolly_v2.py +++ b/src/openllm/models/dolly_v2/configuration_dolly_v2.py @@ -15,8 +15,9 @@ from __future__ import annotations import typing as t import openllm if t.TYPE_CHECKING: import transformers + class DollyV2Config(openllm.LLMConfig): - """Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use. + """Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use. Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming, @@ -27,22 +28,16 @@ class DollyV2Config(openllm.LLMConfig): Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information. """ - __config__ = { - "timeout": 3600000, - "url": "https://github.com/databrickslabs/dolly", - "architecture": "GPTNeoXForCausalLM", - "default_id": "databricks/dolly-v2-3b", - "model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"], - } - return_full_text: bool = openllm.LLMConfig.Field( - False, description="Whether to return the full prompt to the users." - ) - class GenerationConfig: - temperature: float = 0.9 - top_p: float = 0.92 - top_k: int = 5 - max_new_tokens: int = 256 - eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY) + __config__ = {"timeout": 3600000, "url": "https://github.com/databrickslabs/dolly", "architecture": "GPTNeoXForCausalLM", "default_id": "databricks/dolly-v2-3b", "model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],} + return_full_text: bool = openllm.LLMConfig.Field(False, description="Whether to return the full prompt to the users.") + + class GenerationConfig: + temperature: float = 0.9 + top_p: float = 0.92 + top_k: int = 5 + max_new_tokens: int = 256 + eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY) + START_DOLLY_V2_COMMAND_DOCSTRING = """\ Run a LLMServer for dolly-v2 model. @@ -74,8 +69,9 @@ DEFAULT_PROMPT_TEMPLATE = """{intro} {instruction} {response_key} """.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY) + def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int: - """Gets the token ID for a given string that has been added to the tokenizer as a special token. + """Gets the token ID for a given string that has been added to the tokenizer as a special token. When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to. @@ -90,6 +86,6 @@ def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) Returns: int: the token ID for the given key. """ - token_ids = tokenizer.encode(key) - if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}") - return token_ids[0] + token_ids = tokenizer.encode(key) + if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}") + return token_ids[0] diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index 62c7219a..96963478 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -24,104 +24,130 @@ from ..._prompt import process_prompt if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow") logger = logging.getLogger(__name__) -@t.overload -def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: ... -@t.overload -def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: ... -def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline: - # Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information. - class InstructionTextGenerationPipeline(transformers.Pipeline): - def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs) - def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any): - if t.TYPE_CHECKING: assert self.tokenizer is not None - preprocess_params: dict[str, t.Any] = {} - # newer versions of the tokenizer configure the response key as a special token. newer versions still may - # append a newline to yield a single token. find whatever token is configured for the response key. - tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None) - response_key_token_id = None - end_key_token_id = None - if tokenizer_response_key: - try: - response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key) - end_key_token_id = get_special_token_id(self.tokenizer, END_KEY) - # Ensure generation stops once it generates "### End" - generate_kwargs["eos_token_id"] = end_key_token_id - except ValueError: pass - forward_params = generate_kwargs - postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id} - if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text - return preprocess_params, forward_params, postprocess_params - def preprocess(self, input_: str, **generate_kwargs: t.Any): - if t.TYPE_CHECKING: assert self.tokenizer is not None - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_) - inputs = self.tokenizer(prompt_text, return_tensors="pt") - inputs["prompt_text"] = prompt_text - inputs["instruction_text"] = input_ - return inputs - def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any): - if t.TYPE_CHECKING: assert self.tokenizer is not None - input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None) - if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1 - else: in_b = input_ids.shape[0] - generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs) - out_b = generated_sequence.shape[0] - if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) - elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) - instruction_text = model_inputs.pop("instruction_text") - return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text} - def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False): - if t.TYPE_CHECKING: assert self.tokenizer is not None - generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"] - generated_sequence: list[list[int]] = generated_sequence.numpy().tolist() - records: list[dict[t.Literal["generated_text"], str]] = [] - for sequence in generated_sequence: - # The response will be set to this variable if we can identify it. - decoded = None - # If we have token IDs for the response and end, then we can find the tokens and only decode between them. - if response_key_token_id and end_key_token_id: - # Find where "### Response:" is first found in the generated tokens. Considering this is part of the - # prompt, we should definitely find it. We will return the tokens found after this token. - try: response_pos = sequence.index(response_key_token_id) - except ValueError: response_pos = None - if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence) - if response_pos: - # Next find where "### End" is located. The model has been trained to end its responses with this - # sequence (or actually, the token ID it maps to, since it is a special token). We may not find - # this token, as the response could be truncated. If we don't find it then just return everything - # to the end. Note that even though we set eos_token_id, we still see the this token at the end. - try: end_pos = sequence.index(end_key_token_id) - except ValueError: end_pos = None - decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip() - if not decoded: - # Otherwise we'll decode everything and use a regex to find the response and end. - fully_decoded = self.tokenizer.decode(sequence) - # The response appears after "### Response:". The model has been trained to append "### End" at the - # end. - m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL) - if m: decoded = m.group(1).strip() - else: - # The model might not generate the "### End" sequence before reaching the max tokens. In this case, - # return everything after "### Response:". - m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL) - if m: decoded = m.group(1).strip() - else: logger.warning("Failed to find response in:\n%s", fully_decoded) - # If the full text is requested, then append the decoded text to the original instruction. - # This technically isn't the full text, as we format the instruction in the prompt the model has been - # trained on, but to the client it will appear to be the full text. - if return_full_text: decoded = f"{instruction_text}\n{decoded}" - rec = {"generated_text": decoded} - records.append(rec) - return records - if _init: return InstructionTextGenerationPipeline() - return InstructionTextGenerationPipeline +@t.overload +def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: + ... + +@t.overload +def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: + ... + +def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline: + # Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information. + class InstructionTextGenerationPipeline(transformers.Pipeline): + def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): + super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs) + + def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any): + if t.TYPE_CHECKING: assert self.tokenizer is not None + preprocess_params: dict[str, t.Any] = {} + # newer versions of the tokenizer configure the response key as a special token. newer versions still may + # append a newline to yield a single token. find whatever token is configured for the response key. + tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None) + response_key_token_id = None + end_key_token_id = None + if tokenizer_response_key: + try: + response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key) + end_key_token_id = get_special_token_id(self.tokenizer, END_KEY) + # Ensure generation stops once it generates "### End" + generate_kwargs["eos_token_id"] = end_key_token_id + except ValueError: + pass + forward_params = generate_kwargs + postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id} + if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text + return preprocess_params, forward_params, postprocess_params + + def preprocess(self, input_: str, **generate_kwargs: t.Any): + if t.TYPE_CHECKING: assert self.tokenizer is not None + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_) + inputs = self.tokenizer(prompt_text, return_tensors="pt") + inputs["prompt_text"] = prompt_text + inputs["instruction_text"] = input_ + return inputs + + def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any): + if t.TYPE_CHECKING: assert self.tokenizer is not None + input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None) + if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1 + else: in_b = input_ids.shape[0] + generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs) + out_b = generated_sequence.shape[0] + if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) + elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) + instruction_text = model_inputs.pop("instruction_text") + return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text} + + def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False): + if t.TYPE_CHECKING: assert self.tokenizer is not None + generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"] + generated_sequence: list[list[int]] = generated_sequence.numpy().tolist() + records: list[dict[t.Literal["generated_text"], str]] = [] + for sequence in generated_sequence: + # The response will be set to this variable if we can identify it. + decoded = None + # If we have token IDs for the response and end, then we can find the tokens and only decode between them. + if response_key_token_id and end_key_token_id: + # Find where "### Response:" is first found in the generated tokens. Considering this is part of the + # prompt, we should definitely find it. We will return the tokens found after this token. + try: + response_pos = sequence.index(response_key_token_id) + except ValueError: + response_pos = None + if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence) + if response_pos: + # Next find where "### End" is located. The model has been trained to end its responses with this + # sequence (or actually, the token ID it maps to, since it is a special token). We may not find + # this token, as the response could be truncated. If we don't find it then just return everything + # to the end. Note that even though we set eos_token_id, we still see the this token at the end. + try: + end_pos = sequence.index(end_key_token_id) + except ValueError: + end_pos = None + decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip() + if not decoded: + # Otherwise we'll decode everything and use a regex to find the response and end. + fully_decoded = self.tokenizer.decode(sequence) + # The response appears after "### Response:". The model has been trained to append "### End" at the + # end. + m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL) + if m: decoded = m.group(1).strip() + else: + # The model might not generate the "### End" sequence before reaching the max tokens. In this case, + # return everything after "### Response:". + m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL) + if m: decoded = m.group(1).strip() + else: logger.warning("Failed to find response in:\n%s", fully_decoded) + # If the full text is requested, then append the decoded text to the original instruction. + # This technically isn't the full text, as we format the instruction in the prompt the model has been + # trained on, but to the client it will appear to be the full text. + if return_full_text: decoded = f"{instruction_text}\n{decoded}" + rec = {"generated_text": decoded} + records.append(rec) + return records + + if _init: return InstructionTextGenerationPipeline() + return InstructionTextGenerationPipeline + class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]): - __openllm_internal__ = True - @property - def import_kwargs(self): return {"device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16}, {"padding_side": "left"} - def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(model=transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), tokenizer=self.tokenizer, _init=True, return_full_text=self.config.return_full_text) - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {} - def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"] - def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]: - llm_config = self.config.model_construct_env(**attrs) - with torch.inference_mode(): return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config()) + __openllm_internal__ = True + + @property + def import_kwargs(self): + return {"device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16}, {"padding_side": "left"} + + def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: + return get_pipeline(model=transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), tokenizer=self.tokenizer, _init=True, return_full_text=self.config.return_full_text) + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {} + + def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: + return generation_result[0]["generated_text"] + + def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]: + llm_config = self.config.model_construct_env(**attrs) + with torch.inference_mode(): + return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config()) diff --git a/src/openllm/models/falcon/__init__.py b/src/openllm/models/falcon/__init__.py index 462f518a..61729c6e 100644 --- a/src/openllm/models/falcon/__init__.py +++ b/src/openllm/models/falcon/__init__.py @@ -17,17 +17,23 @@ import typing as t from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_falcon": ["FalconConfig", "START_FALCON_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_falcon"] = ["Falcon"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_falcon"] = ["Falcon"] if t.TYPE_CHECKING: - from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING - from .configuration_falcon import FalconConfig as FalconConfig - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_falcon import Falcon as Falcon -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING + from .configuration_falcon import FalconConfig as FalconConfig + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_falcon import Falcon as Falcon +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/falcon/configuration_falcon.py b/src/openllm/models/falcon/configuration_falcon.py index 176cc95e..1bb24a61 100644 --- a/src/openllm/models/falcon/configuration_falcon.py +++ b/src/openllm/models/falcon/configuration_falcon.py @@ -13,45 +13,26 @@ # limitations under the License. from __future__ import annotations import openllm + class FalconConfig(openllm.LLMConfig): - """Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora. + """Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora. It is made available under the TII Falcon LLM License. Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information. """ - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "requires_gpu": True, - "timeout": int(36e6), - "url": "https://falconllm.tii.ae/", - "requirements": ["einops", "xformers"], - "architecture": "FalconForCausalLM", - "default_id": "tiiuae/falcon-7b", - "model_ids": [ - "tiiuae/falcon-7b", - "tiiuae/falcon-40b", - "tiiuae/falcon-7b-instruct", - "tiiuae/falcon-40b-instruct", - ], - "fine_tune_strategies": ( - { - "adapter_type": "lora", - "r": 64, - "lora_alpha": 16, - "lora_dropout": 0.1, - "bias": "none", - "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - }, - ), - } - class GenerationConfig: - max_new_tokens: int = 200 - top_k: int = 10 - num_return_sequences: int = 1 - num_beams: int = 4 - early_stopping: bool = True + __config__ = { + "name_type": "lowercase", "trust_remote_code": True, "requires_gpu": True, "timeout": int(36e6), "url": "https://falconllm.tii.ae/", "requirements": ["einops", "xformers"], "architecture": "FalconForCausalLM", "default_id": "tiiuae/falcon-7b", "model_ids": ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct",], + "fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none", "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],},), + } + + class GenerationConfig: + max_new_tokens: int = 200 + top_k: int = 10 + num_return_sequences: int = 1 + num_beams: int = 4 + early_stopping: bool = True + START_FALCON_COMMAND_DOCSTRING = """\ Run a LLMServer for FalconLM model. diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py index cc7f2263..99be5f8d 100644 --- a/src/openllm/models/falcon/modeling_falcon.py +++ b/src/openllm/models/falcon/modeling_falcon.py @@ -18,21 +18,31 @@ from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") + class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]): - __openllm_internal__ = True - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() else None}, {} - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device) - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env( eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True) - def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: - max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) - src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([])) - stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) - result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) - # Inference API returns the stop sequence - for stop_seq in stop: - if result.endswith(stop_seq): result = result[: -len(stop_seq)] - return [{"generated_text": result}] + __openllm_internal__ = True + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() else None}, {} + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device) + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): + return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True) + + def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: + max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) + src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([])) + stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) + result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) + # Inference API returns the stop sequence + for stop_seq in stop: + if result.endswith(stop_seq): result = result[:-len(stop_seq)] + return [{"generated_text": result}] diff --git a/src/openllm/models/flan_t5/__init__.py b/src/openllm/models/flan_t5/__init__.py index 2d5c97d1..7fa152ba 100644 --- a/src/openllm/models/flan_t5/__init__.py +++ b/src/openllm/models/flan_t5/__init__.py @@ -20,33 +20,47 @@ from ...utils import LazyModule from ...utils import is_flax_available from ...utils import is_tf_available from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_flan_t5": ["FlanT5Config", "START_FLAN_T5_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_flan_t5"] = ["FlanT5"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_flan_t5"] = ["FlanT5"] try: - if not is_flax_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"] + if not is_flax_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"] try: - if not is_tf_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"] + if not is_tf_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"] if t.TYPE_CHECKING: - from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING - from .configuration_flan_t5 import FlanT5Config as FlanT5Config - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_flan_t5 import FlanT5 as FlanT5 - try: - if not is_flax_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5 - try: - if not is_tf_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5 -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING + from .configuration_flan_t5 import FlanT5Config as FlanT5Config + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_flan_t5 import FlanT5 as FlanT5 + try: + if not is_flax_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5 + try: + if not is_tf_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5 +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/flan_t5/configuration_flan_t5.py b/src/openllm/models/flan_t5/configuration_flan_t5.py index 8129273b..c5781987 100644 --- a/src/openllm/models/flan_t5/configuration_flan_t5.py +++ b/src/openllm/models/flan_t5/configuration_flan_t5.py @@ -13,32 +13,23 @@ # limitations under the License. from __future__ import annotations import openllm + class FlanT5Config(openllm.LLMConfig): - """FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf). + """FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf). It is an enhanced version of T5 that has been finetuned in a mixture of tasks. Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information. """ - __config__ = { - "url": "https://huggingface.co/docs/transformers/model_doc/flan-t5", - "default_id": "google/flan-t5-large", - "architecture": "T5ForConditionalGeneration", - "model_ids": [ - "google/flan-t5-small", - "google/flan-t5-base", - "google/flan-t5-large", - "google/flan-t5-xl", - "google/flan-t5-xxl", - ], - "model_type": "seq2seq_lm", - } - class GenerationConfig: - temperature: float = 0.9 - max_new_tokens: int = 2048 - top_k: int = 50 - top_p: float = 0.4 - repetition_penalty = 1.0 + __config__ = {"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5", "default_id": "google/flan-t5-large", "architecture": "T5ForConditionalGeneration", "model_ids": ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl",], "model_type": "seq2seq_lm",} + + class GenerationConfig: + temperature: float = 0.9 + max_new_tokens: int = 2048 + top_k: int = 50 + top_p: float = 0.4 + repetition_penalty = 1.0 + START_FLAN_T5_COMMAND_DOCSTRING = """\ Run a LLMServer for FLAN-T5 model. diff --git a/src/openllm/models/flan_t5/modeling_flan_t5.py b/src/openllm/models/flan_t5/modeling_flan_t5.py index d0d8840e..61f4b01b 100644 --- a/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -19,20 +19,28 @@ from ..._llm import LLMEmbeddings from ..._prompt import process_prompt if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional") + class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformers.T5TokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) - def embeddings(self, prompts: list[str]) -> LLMEmbeddings: - embeddings: list[list[float]] = [] - num_tokens = 0 - for prompt in prompts: - input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device) - with torch.inference_mode(): - outputs = self.model(input_ids, decoder_input_ids=input_ids) - data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0) - embeddings.append(data.tolist()) - num_tokens += len(input_ids[0]) - return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + with torch.inference_mode(): + return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) + + def embeddings(self, prompts: list[str]) -> LLMEmbeddings: + embeddings: list[list[float]] = [] + num_tokens = 0 + for prompt in prompts: + input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device) + with torch.inference_mode(): + outputs = self.model(input_ids, decoder_input_ids=input_ids) + data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0) + embeddings.append(data.tolist()) + num_tokens += len(input_ids[0]) + return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens) diff --git a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py index 53be8e0f..21337887 100644 --- a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py @@ -17,13 +17,18 @@ import openllm from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: import transformers # noqa: F401 + class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - if decoder_start_token_id is None: decoder_start_token_id = 0 - return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - # NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation. - decoder_start_token_id = attrs.pop("decoder_start_token_id", 0) - return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="np")["input_ids"], do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), decoder_start_token_id=decoder_start_token_id).sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + if decoder_start_token_id is None: decoder_start_token_id = 0 + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + # NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation. + decoder_start_token_id = attrs.pop("decoder_start_token_id", 0) + return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="np")["input_ids"], do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), decoder_start_token_id=decoder_start_token_id).sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True) diff --git a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py index eae63692..5f109202 100644 --- a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py @@ -17,8 +17,15 @@ import openllm from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: import transformers # noqa: F401 + class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) diff --git a/src/openllm/models/gpt_neox/__init__.py b/src/openllm/models/gpt_neox/__init__.py index 96aebfaa..2c1d59df 100644 --- a/src/openllm/models/gpt_neox/__init__.py +++ b/src/openllm/models/gpt_neox/__init__.py @@ -17,17 +17,23 @@ import typing as t from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_gpt_neox": ["GPTNeoXConfig", "START_GPT_NEOX_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_gpt_neox"] = ["GPTNeoX"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_gpt_neox"] = ["GPTNeoX"] if t.TYPE_CHECKING: - from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING - from .configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_gpt_neox import GPTNeoX as GPTNeoX -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING + from .configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_gpt_neox import GPTNeoX as GPTNeoX +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/gpt_neox/configuration_gpt_neox.py b/src/openllm/models/gpt_neox/configuration_gpt_neox.py index 302a912a..5109eebb 100644 --- a/src/openllm/models/gpt_neox/configuration_gpt_neox.py +++ b/src/openllm/models/gpt_neox/configuration_gpt_neox.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations import openllm + class GPTNeoXConfig(openllm.LLMConfig): - """GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license. + """GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license. It is, to the best of our knowledge, the largest dense autoregressive model that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights, @@ -28,19 +29,13 @@ class GPTNeoXConfig(openllm.LLMConfig): Refer to [GPTNeoX's model card](https://huggingface.co/docs/transformers/model_doc/gpt_neox) for more information. """ - __config__ = { - "model_name": "gpt_neox", - "start_name": "gpt-neox", - "requires_gpu": True, - "architecture": "GPTNeoXForCausalLM", - "url": "https://github.com/EleutherAI/gpt-neox", - "default_id": "eleutherai/gpt-neox-20b", - "model_ids": ["eleutherai/gpt-neox-20b"], - } - use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.") - class GenerationConfig: - temperature: float = 0.9 - max_new_tokens: int = 100 + __config__ = {"model_name": "gpt_neox", "start_name": "gpt-neox", "requires_gpu": True, "architecture": "GPTNeoXForCausalLM", "url": "https://github.com/EleutherAI/gpt-neox", "default_id": "eleutherai/gpt-neox-20b", "model_ids": ["eleutherai/gpt-neox-20b"],} + use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.") + + class GenerationConfig: + temperature: float = 0.9 + max_new_tokens: int = 100 + START_GPT_NEOX_COMMAND_DOCSTRING = """\ Run a LLMServer for GPTNeoX model. diff --git a/src/openllm/models/gpt_neox/modeling_gpt_neox.py b/src/openllm/models/gpt_neox/modeling_gpt_neox.py index c639d873..72044889 100644 --- a/src/openllm/models/gpt_neox/modeling_gpt_neox.py +++ b/src/openllm/models/gpt_neox/modeling_gpt_neox.py @@ -20,15 +20,25 @@ from ..._prompt import process_prompt if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) + class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {} - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {} - def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] - def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM: - model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs) - if self.config.use_half_precision: model.half() - return model - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {} + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {} + + def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: + return generation_result[0] + + def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM: + model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs) + if self.config.use_half_precision: model.half() + return model + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + with torch.inference_mode(): + return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))) diff --git a/src/openllm/models/llama/__init__.py b/src/openllm/models/llama/__init__.py index 143ceea6..56e694e6 100644 --- a/src/openllm/models/llama/__init__.py +++ b/src/openllm/models/llama/__init__.py @@ -18,26 +18,36 @@ from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available from ...utils import is_vllm_available + _import_structure: dict[str, list[str]] = {"configuration_llama": ["LlamaConfig", "START_LLAMA_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]} try: - if not is_vllm_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_vllm_llama"] = ["VLLMLlama"] + if not is_vllm_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_vllm_llama"] = ["VLLMLlama"] try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_llama"] = ["Llama"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_llama"] = ["Llama"] if t.TYPE_CHECKING: - from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING - from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING - from .configuration_llama import LlamaConfig as LlamaConfig - try: - if not is_vllm_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_vllm_llama import VLLMLlama as VLLMLlama - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_llama import Llama as Llama -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING + from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING + from .configuration_llama import LlamaConfig as LlamaConfig + try: + if not is_vllm_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_vllm_llama import VLLMLlama as VLLMLlama + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_llama import Llama as Llama +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/llama/configuration_llama.py b/src/openllm/models/llama/configuration_llama.py index b7d9005d..7b442059 100644 --- a/src/openllm/models/llama/configuration_llama.py +++ b/src/openllm/models/llama/configuration_llama.py @@ -14,8 +14,9 @@ from __future__ import annotations import typing as t import openllm + class LlamaConfig(openllm.LLMConfig): - """LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. + """LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. It is a collection of foundation language models ranging from 7B to 65B parameters. @@ -26,54 +27,24 @@ class LlamaConfig(openllm.LLMConfig): Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama) for more information. """ - use_llama2_prompt: bool = openllm.LLMConfig.Field(True, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.") - __config__ = { - "name_type": "lowercase", - "url": "https://github.com/facebookresearch/llama", - "default_id": "huggyllama/llama-7b", - "default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"}, - "architecture": "LlamaForCausalLM", - "requirements": ["fairscale", "sentencepiece"], - "model_ids": [ - "meta-llama/Llama-2-70b-chat-hf", - "meta-llama/Llama-2-13b-chat-hf", - "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-70b-hf", - "meta-llama/Llama-2-13b-hf", - "meta-llama/Llama-2-7b-hf", - "NousResearch/llama-2-70b-chat-hf", - "NousResearch/llama-2-13b-chat-hf", - "NousResearch/llama-2-7b-chat-hf", - "NousResearch/llama-2-70b-hf", - "NousResearch/llama-2-13b-hf", - "NousResearch/llama-2-7b-hf", - "openlm-research/open_llama_7b_v2", - "openlm-research/open_llama_3b_v2", - "openlm-research/open_llama_13b", - "huggyllama/llama-65b", - "huggyllama/llama-30b", - "huggyllama/llama-13b", - "huggyllama/llama-7b", - ], - "tokenizer_class": "LlamaTokenizerFast", - "fine_tune_strategies": ( - { - "adapter_type": "lora", - "r": 64, - "lora_alpha": 16, - "lora_dropout": 0.1, - "bias": "none", - }, - ), - } - class GenerationConfig: - max_new_tokens: int = 256 - temperature: float = 0.45 - top_p: float = 0.95 - top_k: int = 12 - class SamplingParams: - best_of: int = 1 - presence_penalty: float = 0.5 + use_llama2_prompt: bool = openllm.LLMConfig.Field(True, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.") + __config__ = { + "name_type": "lowercase", "url": "https://github.com/facebookresearch/llama", "default_id": "huggyllama/llama-7b", "default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"}, "architecture": "LlamaForCausalLM", "requirements": ["fairscale", "sentencepiece"], "model_ids": [ + "meta-llama/Llama-2-70b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "NousResearch/llama-2-70b-chat-hf", "NousResearch/llama-2-13b-chat-hf", "NousResearch/llama-2-7b-chat-hf", "NousResearch/llama-2-70b-hf", "NousResearch/llama-2-13b-hf", + "NousResearch/llama-2-7b-hf", "openlm-research/open_llama_7b_v2", "openlm-research/open_llama_3b_v2", "openlm-research/open_llama_13b", "huggyllama/llama-65b", "huggyllama/llama-30b", "huggyllama/llama-13b", "huggyllama/llama-7b", + ], "tokenizer_class": "LlamaTokenizerFast", "fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none",},), + } + + class GenerationConfig: + max_new_tokens: int = 256 + temperature: float = 0.45 + top_p: float = 0.95 + top_k: int = 12 + + class SamplingParams: + best_of: int = 1 + presence_penalty: float = 0.5 + START_LLAMA_COMMAND_DOCSTRING = """\ Run a LLMServer for Llama model. @@ -112,5 +83,7 @@ SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = "[INST]", "[/INST]", "< str: return PROMPT_MAPPING[model_type] + DEFAULT_PROMPT_TEMPLATE = _get_prompt diff --git a/src/openllm/models/llama/modeling_llama.py b/src/openllm/models/llama/modeling_llama.py index 130bc1f5..b104354e 100644 --- a/src/openllm/models/llama/modeling_llama.py +++ b/src/openllm/models/llama/modeling_llama.py @@ -21,22 +21,28 @@ from ..._prompt import process_prompt if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional") logger = logging.getLogger(__name__) + class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - _template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None - return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {} - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {} - def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True) - def embeddings(self, prompts: list[str]) -> LLMEmbeddings: - encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device) - input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"] - with torch.inference_mode(): - data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1] - mask = attention_mask.unsqueeze(-1).expand(data.size()).float() - masked_embeddings = data * mask - sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1) - return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item()) + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + _template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None + return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {} + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {} + + def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True) + + def embeddings(self, prompts: list[str]) -> LLMEmbeddings: + encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device) + input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"] + with torch.inference_mode(): + data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1) + return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item()) diff --git a/src/openllm/models/llama/modeling_vllm_llama.py b/src/openllm/models/llama/modeling_vllm_llama.py index 9df47b8b..517a8942 100644 --- a/src/openllm/models/llama/modeling_vllm_llama.py +++ b/src/openllm/models/llama/modeling_vllm_llama.py @@ -19,8 +19,10 @@ from .configuration_llama import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: import vllm, transformers logger = logging.getLogger(__name__) + class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]): - __openllm_internal__ = True - def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - _template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None - return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {} + __openllm_internal__ = True + + def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + _template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None + return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {} diff --git a/src/openllm/models/mpt/__init__.py b/src/openllm/models/mpt/__init__.py index 8c1168cc..7db58020 100644 --- a/src/openllm/models/mpt/__init__.py +++ b/src/openllm/models/mpt/__init__.py @@ -17,18 +17,24 @@ import typing as t from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_mpt": ["MPTConfig", "START_MPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_mpt"] = ["MPT"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_mpt"] = ["MPT"] if t.TYPE_CHECKING: - from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING - from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING - from .configuration_mpt import MPTConfig as MPTConfig - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_mpt import MPT as MPT -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING + from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING + from .configuration_mpt import MPTConfig as MPTConfig + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_mpt import MPT as MPT +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/mpt/configuration_mpt.py b/src/openllm/models/mpt/configuration_mpt.py index 88b1cd97..e06359b4 100644 --- a/src/openllm/models/mpt/configuration_mpt.py +++ b/src/openllm/models/mpt/configuration_mpt.py @@ -16,8 +16,9 @@ import typing as t import openllm if t.TYPE_CHECKING: MPTPromptType = t.Literal["default", "instruct", "chat", "storywriter"] else: MPTPromptType = str + class MPTConfig(openllm.LLMConfig): - """MPT is a decoder-style transformer pretrained from scratch on English text and code. + """MPT is a decoder-style transformer pretrained from scratch on English text and code. This model was trained by [MosaicML](https://www.mosaicml.com/). @@ -25,30 +26,18 @@ class MPTConfig(openllm.LLMConfig): on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml) for more details on specific models. """ - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "url": "https://huggingface.co/mosaicml", - "default_id": "mosaicml/mpt-7b-instruct", - "timeout": int(36e6), - "requirements": ["triton", "einops"], - "architecture": "MPTForCausalLM", - "model_ids": [ - "mosaicml/mpt-7b", - "mosaicml/mpt-7b-instruct", - "mosaicml/mpt-7b-chat", - "mosaicml/mpt-7b-storywriter", - "mosaicml/mpt-30b", - "mosaicml/mpt-30b-instruct", - "mosaicml/mpt-30b-chat", - ], - } - prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="""Given prompt type for running MPT. Default will be inferred from model name if pretrained.""") - max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)") - class GenerationConfig: - max_new_tokens: int = 128 - temperature: float = 0 - top_p: float = 0.8 + __config__ = { + "name_type": "lowercase", "trust_remote_code": True, "url": "https://huggingface.co/mosaicml", "default_id": "mosaicml/mpt-7b-instruct", "timeout": int(36e6), "requirements": ["triton", "einops"], "architecture": "MPTForCausalLM", + "model_ids": ["mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct", "mosaicml/mpt-7b-chat", "mosaicml/mpt-7b-storywriter", "mosaicml/mpt-30b", "mosaicml/mpt-30b-instruct", "mosaicml/mpt-30b-chat",], + } + prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="""Given prompt type for running MPT. Default will be inferred from model name if pretrained.""") + max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)") + + class GenerationConfig: + max_new_tokens: int = 128 + temperature: float = 0 + top_p: float = 0.8 + START_MPT_COMMAND_DOCSTRING = """\ Run a LLMServer for MPT model. @@ -86,5 +75,8 @@ _chat_prompt, _default_prompt, _instruct_prompt = """{instruction}""", """{instr {response_key} """.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY) PROMPT_MAPPING = {"default": _default_prompt, "instruct": _instruct_prompt, "storywriter": _default_prompt, "chat": _chat_prompt} -def _get_prompt(model_type: str) -> str: return PROMPT_MAPPING[model_type] + +def _get_prompt(model_type: str) -> str: + return PROMPT_MAPPING[model_type] + DEFAULT_PROMPT_TEMPLATE = _get_prompt diff --git a/src/openllm/models/mpt/modeling_mpt.py b/src/openllm/models/mpt/modeling_mpt.py index a62c72e0..da60f365 100644 --- a/src/openllm/models/mpt/modeling_mpt.py +++ b/src/openllm/models/mpt/modeling_mpt.py @@ -23,56 +23,71 @@ from ...utils import generate_labels, is_triton_available if t.TYPE_CHECKING: import transformers, torch else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch") logger = logging.getLogger(__name__) + def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig: - config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) - if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device) - if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton" - else: logger.debug("'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'") - # setting max_seq_len - config.max_seq_len = max_sequence_length - return config + config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device) + if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton" + else: logger.debug("'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'") + # setting max_seq_len + config.max_seq_len = max_sequence_length + return config + class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXTokenizerFast"]): - __openllm_internal__ = True - def llm_post_init(self): self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"} - def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model: - _, tokenizer_attrs = self.llm_parameters - torch_dtype = attrs.pop("torch_dtype", self.dtype) - device_map = attrs.pop("device_map", None) - attrs.pop("low_cpu_mem_usage", None) - config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code) - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) - if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token - model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs) - try: return bentoml.transformers.save_model( self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) - finally: torch.cuda.empty_cache() - def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel: - torch_dtype = attrs.pop("torch_dtype", self.dtype) - device_map = attrs.pop("device_map", None) - trust_remote_code = attrs.pop("trust_remote_code", True) - config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,) - model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs) - model.tie_weights() - return model - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - _template = None - if use_default_prompt_template: - if prompt_type is None: - if "instruct" in self.model_id: prompt_type = "instruct" - elif "storywriter" in self.model_id: prompt_type = "storywriter" - elif "chat" in self.model_id: prompt_type = "chat" - else: prompt_type = "default" - _template = DEFAULT_PROMPT_TEMPLATE(prompt_type) - return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - llm_config = self.config.model_construct_env(**attrs) - inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) - attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()} - with torch.inference_mode(): - if torch.cuda.is_available(): - with torch.autocast("cuda", torch.float16): - generated_tensors = self.model.generate(**inputs, **attrs) - else: generated_tensors = self.model.generate(**inputs, **attrs) - return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True) + __openllm_internal__ = True + + def llm_post_init(self): + self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"} + + def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model: + _, tokenizer_attrs = self.llm_parameters + torch_dtype = attrs.pop("torch_dtype", self.dtype) + device_map = attrs.pop("device_map", None) + attrs.pop("low_cpu_mem_usage", None) + config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code) + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token + model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs) + try: + return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) + finally: + torch.cuda.empty_cache() + + def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel: + torch_dtype = attrs.pop("torch_dtype", self.dtype) + device_map = attrs.pop("device_map", None) + trust_remote_code = attrs.pop("trust_remote_code", True) + config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,) + model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs) + model.tie_weights() + return model + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + _template = None + if use_default_prompt_template: + if prompt_type is None: + if "instruct" in self.model_id: prompt_type = "instruct" + elif "storywriter" in self.model_id: prompt_type = "storywriter" + elif "chat" in self.model_id: prompt_type = "chat" + else: prompt_type = "default" + _template = DEFAULT_PROMPT_TEMPLATE(prompt_type) + return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + llm_config = self.config.model_construct_env(**attrs) + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()} + with torch.inference_mode(): + if torch.cuda.is_available(): + with torch.autocast("cuda", torch.float16): + generated_tensors = self.model.generate(**inputs, **attrs) + else: + generated_tensors = self.model.generate(**inputs, **attrs) + return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True) diff --git a/src/openllm/models/opt/__init__.py b/src/openllm/models/opt/__init__.py index 1b7711c3..b903cef2 100644 --- a/src/openllm/models/opt/__init__.py +++ b/src/openllm/models/opt/__init__.py @@ -20,41 +20,59 @@ from ...utils import is_flax_available from ...utils import is_tf_available from ...utils import is_torch_available from ...utils import is_vllm_available + _import_structure: dict[str, list[str]] = {"configuration_opt": ["OPTConfig", "START_OPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_opt"] = ["OPT"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_opt"] = ["OPT"] try: - if not is_flax_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_flax_opt"] = ["FlaxOPT"] + if not is_flax_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_flax_opt"] = ["FlaxOPT"] try: - if not is_vllm_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_vllm_opt"] = ["VLLMOPT"] + if not is_vllm_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_vllm_opt"] = ["VLLMOPT"] try: - if not is_tf_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_tf_opt"] = ["TFOPT"] + if not is_tf_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_tf_opt"] = ["TFOPT"] if t.TYPE_CHECKING: - from .configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING - from .configuration_opt import OPTConfig as OPTConfig - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_opt import OPT as OPT - try: - if not is_flax_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_flax_opt import FlaxOPT as FlaxOPT - try: - if not is_vllm_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_vllm_opt import VLLMOPT as VLLMOPT - try: - if not is_tf_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_tf_opt import TFOPT as TFOPT -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING + from .configuration_opt import OPTConfig as OPTConfig + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_opt import OPT as OPT + try: + if not is_flax_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_flax_opt import FlaxOPT as FlaxOPT + try: + if not is_vllm_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_vllm_opt import VLLMOPT as VLLMOPT + try: + if not is_tf_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_tf_opt import TFOPT as TFOPT +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/opt/configuration_opt.py b/src/openllm/models/opt/configuration_opt.py index 7abe26e5..bb9c27b6 100644 --- a/src/openllm/models/opt/configuration_opt.py +++ b/src/openllm/models/opt/configuration_opt.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations import openllm + class OPTConfig(openllm.LLMConfig): - """OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI. + """OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI. OPT was predominantly pretrained with English text, but a small amount of non-English data is still present within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM) @@ -23,37 +24,18 @@ class OPTConfig(openllm.LLMConfig): Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information. """ - __config__ = { - "name_type": "lowercase", - "trust_remote_code": False, - "url": "https://huggingface.co/docs/transformers/model_doc/opt", - "default_id": "facebook/opt-1.3b", - "architecture": "OPTForCausalLM", - "model_ids": [ - "facebook/opt-125m", - "facebook/opt-350m", - "facebook/opt-1.3b", - "facebook/opt-2.7b", - "facebook/opt-6.7b", - "facebook/opt-66b", - ], - "fine_tune_strategies": ( - { - "adapter_type": "lora", - "r": 16, - "lora_alpha": 32, - "target_modules": ["q_proj", "v_proj"], - "lora_dropout": 0.05, - "bias": "none", - }, - ), - } - format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""") - class GenerationConfig: - top_k: int = 15 - temperature: float = 0.75 - max_new_tokens: int = 1024 - num_return_sequences: int = 1 + __config__ = { + "name_type": "lowercase", "trust_remote_code": False, "url": "https://huggingface.co/docs/transformers/model_doc/opt", "default_id": "facebook/opt-1.3b", "architecture": "OPTForCausalLM", "model_ids": ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-66b",], + "fine_tune_strategies": ({"adapter_type": "lora", "r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], "lora_dropout": 0.05, "bias": "none",},), + } + format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""") + + class GenerationConfig: + top_k: int = 15 + temperature: float = 0.75 + max_new_tokens: int = 1024 + num_return_sequences: int = 1 + START_OPT_COMMAND_DOCSTRING = """\ Run a LLMServer for OPT model. diff --git a/src/openllm/models/opt/modeling_flax_opt.py b/src/openllm/models/opt/modeling_flax_opt.py index aad4d6a5..8404b492 100644 --- a/src/openllm/models/opt/modeling_flax_opt.py +++ b/src/openllm/models/opt/modeling_flax_opt.py @@ -22,18 +22,26 @@ from ...utils import generate_labels if t.TYPE_CHECKING: import transformers else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) -class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]): - __openllm_internal__ = True - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {}, {"padding_side": "left", "truncation_side": "left"} - def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1]) - tokenizer.pad_token_id = config.pad_token_id - return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: - if len(generation_result) == 1: return generation_result[0] - if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result) - else: return "\n".join(generation_result) - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode( self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True) +class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]): + __openllm_internal__ = True + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {}, {"padding_side": "left", "truncation_side": "left"} + + def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: + config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1]) + tokenizer.pad_token_id = config.pad_token_id + return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: + if len(generation_result) == 1: return generation_result[0] + if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result) + else: return "\n".join(generation_result) + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True) diff --git a/src/openllm/models/opt/modeling_opt.py b/src/openllm/models/opt/modeling_opt.py index 0a17076d..c9765604 100644 --- a/src/openllm/models/opt/modeling_opt.py +++ b/src/openllm/models/opt/modeling_opt.py @@ -18,23 +18,34 @@ import openllm from .configuration_opt import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: - import torch, transformers + import torch, transformers else: - torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") + torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) + class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]): - __openllm_internal__ = True - def llm_post_init(self): self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left", "truncation_side": "left"} - def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM: - torch_dtype = attrs.pop("torch_dtype", self.dtype) - model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs) - return model - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: - if len(generation_result) == 1: return generation_result[0] - if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result) - else: return "\n".join(generation_result) - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) + __openllm_internal__ = True + + def llm_post_init(self): + self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left", "truncation_side": "left"} + + def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM: + torch_dtype = attrs.pop("torch_dtype", self.dtype) + model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs) + return model + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: + if len(generation_result) == 1: return generation_result[0] + if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result) + else: return "\n".join(generation_result) + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + with torch.inference_mode(): + return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) diff --git a/src/openllm/models/opt/modeling_tf_opt.py b/src/openllm/models/opt/modeling_tf_opt.py index 7169a257..c22aad2c 100644 --- a/src/openllm/models/opt/modeling_tf_opt.py +++ b/src/openllm/models/opt/modeling_tf_opt.py @@ -22,17 +22,26 @@ from ...utils import generate_labels if t.TYPE_CHECKING: import transformers else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) + class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]): - __openllm_internal__ = True - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {}, {"padding_side": "left", "truncation_side": "left"} - def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1]) - tokenizer.pad_token_id = config.pad_token_id - return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: - if len(generation_result) == 1: return generation_result[0] - if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result) - else: return "\n".join(generation_result) - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) + __openllm_internal__ = True + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {}, {"padding_side": "left", "truncation_side": "left"} + + def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: + config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1]) + tokenizer.pad_token_id = config.pad_token_id + return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: + if len(generation_result) == 1: return generation_result[0] + if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result) + else: return "\n".join(generation_result) + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) diff --git a/src/openllm/models/opt/modeling_vllm_opt.py b/src/openllm/models/opt/modeling_vllm_opt.py index e8b573b2..482ace1b 100644 --- a/src/openllm/models/opt/modeling_vllm_opt.py +++ b/src/openllm/models/opt/modeling_vllm_opt.py @@ -19,9 +19,13 @@ import openllm from .configuration_opt import DEFAULT_PROMPT_TEMPLATE from ..._prompt import process_prompt if t.TYPE_CHECKING: - import vllm, transformers + import vllm, transformers + logger = logging.getLogger(__name__) + class VLLMOPT(openllm.LLM["vllm.LLMEngine", "transformers.GPT2Tokenizer"]): - __openllm_internal__ = True - tokenizer_id = "local" - def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {} + __openllm_internal__ = True + tokenizer_id = "local" + + def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {} diff --git a/src/openllm/models/stablelm/__init__.py b/src/openllm/models/stablelm/__init__.py index fbcbc5f6..1d8405d2 100644 --- a/src/openllm/models/stablelm/__init__.py +++ b/src/openllm/models/stablelm/__init__.py @@ -17,17 +17,23 @@ import typing as t from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_stablelm": ["StableLMConfig", "START_STABLELM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_stablelm"] = ["StableLM"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_stablelm"] = ["StableLM"] if t.TYPE_CHECKING: - from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING - from .configuration_stablelm import StableLMConfig as StableLMConfig - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_stablelm import StableLM as StableLM -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING + from .configuration_stablelm import StableLMConfig as StableLMConfig + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_stablelm import StableLM as StableLM +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/stablelm/configuration_stablelm.py b/src/openllm/models/stablelm/configuration_stablelm.py index aa21ca23..7e40a44d 100644 --- a/src/openllm/models/stablelm/configuration_stablelm.py +++ b/src/openllm/models/stablelm/configuration_stablelm.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations import openllm + class StableLMConfig(openllm.LLMConfig): - """StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models. + """StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models. It is pre-trained on a diverse collection of English datasets with a sequence length of 4096 to push beyond the context window limitations of existing open-source language models. @@ -27,23 +28,14 @@ class StableLMConfig(openllm.LLMConfig): and [StableLM-base's model card](https://huggingface.co/stabilityai/stablelm-base-alpha-7b) for more information. """ - __config__ = { - "name_type": "lowercase", - "url": "https://github.com/Stability-AI/StableLM", - "architecture": "GPTNeoXForCausalLM", - "default_id": "stabilityai/stablelm-tuned-alpha-3b", - "model_ids": [ - "stabilityai/stablelm-tuned-alpha-3b", - "stabilityai/stablelm-tuned-alpha-7b", - "stabilityai/stablelm-base-alpha-3b", - "stabilityai/stablelm-base-alpha-7b", - ], - } - class GenerationConfig: - temperature: float = 0.9 - max_new_tokens: int = 128 - top_k: int = 0 - top_p: float = 0.9 + __config__ = {"name_type": "lowercase", "url": "https://github.com/Stability-AI/StableLM", "architecture": "GPTNeoXForCausalLM", "default_id": "stabilityai/stablelm-tuned-alpha-3b", "model_ids": ["stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b",],} + + class GenerationConfig: + temperature: float = 0.9 + max_new_tokens: int = 128 + top_k: int = 0 + top_p: float = 0.9 + START_STABLELM_COMMAND_DOCSTRING = """\ Run a LLMServer for StableLM model. diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index 73e6d406..18a1d85f 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -21,17 +21,28 @@ from ..._prompt import process_prompt if t.TYPE_CHECKING: import transformers, torch else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch") logger = logging.getLogger(__name__) + class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]): - __openllm_internal__ = True - def llm_post_init(self): self.bettertransformer = True if not torch.cuda.is_available() else False - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {} - def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - if "tuned" in self._model_id and use_default_prompt_template: - system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT) - prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs) - else: prompt_text = prompt - return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {} - def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)] + __openllm_internal__ = True + + def llm_post_init(self): + self.bettertransformer = True if not torch.cuda.is_available() else False + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {} + + def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + if "tuned" in self._model_id and use_default_prompt_template: + system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT) + prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs) + else: + prompt_text = prompt + return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {} + + def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + with torch.inference_mode(): + return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)] diff --git a/src/openllm/models/starcoder/__init__.py b/src/openllm/models/starcoder/__init__.py index 051af027..1b6f0587 100644 --- a/src/openllm/models/starcoder/__init__.py +++ b/src/openllm/models/starcoder/__init__.py @@ -17,17 +17,23 @@ import typing as t from ...exceptions import MissingDependencyError from ...utils import LazyModule from ...utils import is_torch_available + _import_structure: dict[str, list[str]] = {"configuration_starcoder": ["StarCoderConfig", "START_STARCODER_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]} try: - if not is_torch_available(): raise MissingDependencyError -except MissingDependencyError: pass -else: _import_structure["modeling_starcoder"] = ["StarCoder"] + if not is_torch_available(): raise MissingDependencyError +except MissingDependencyError: + pass +else: + _import_structure["modeling_starcoder"] = ["StarCoder"] if t.TYPE_CHECKING: - from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE - from .configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING - from .configuration_starcoder import StarCoderConfig as StarCoderConfig - try: - if not is_torch_available(): raise MissingDependencyError - except MissingDependencyError: pass - else: from .modeling_starcoder import StarCoder as StarCoder -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE + from .configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING + from .configuration_starcoder import StarCoderConfig as StarCoderConfig + try: + if not is_torch_available(): raise MissingDependencyError + except MissingDependencyError: + pass + else: + from .modeling_starcoder import StarCoder as StarCoder +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/starcoder/configuration_starcoder.py b/src/openllm/models/starcoder/configuration_starcoder.py index 7349b673..23f048e7 100644 --- a/src/openllm/models/starcoder/configuration_starcoder.py +++ b/src/openllm/models/starcoder/configuration_starcoder.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations import openllm + class StarCoderConfig(openllm.LLMConfig): - """The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded. + """The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded. The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150), [a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the @@ -22,24 +23,17 @@ class StarCoderConfig(openllm.LLMConfig): Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information. """ - __config__ = { - "name_type": "lowercase", - "requires_gpu": True, - "url": "https://github.com/bigcode-project/starcoder", - "architecture": "GPTBigCodeForCausalLM", - "requirements": ["bitsandbytes"], - "workers_per_resource": 0.5, - "default_id": "bigcode/starcoder", - "model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"], - } - class GenerationConfig: - temperature: float = 0.2 - max_new_tokens: int = 256 - min_new_tokens: int = 32 - top_k: float = 50 - top_p: float = 0.95 - pad_token_id: int = 49152 - repetition_penalty: float = 1.2 + __config__ = {"name_type": "lowercase", "requires_gpu": True, "url": "https://github.com/bigcode-project/starcoder", "architecture": "GPTBigCodeForCausalLM", "requirements": ["bitsandbytes"], "workers_per_resource": 0.5, "default_id": "bigcode/starcoder", "model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],} + + class GenerationConfig: + temperature: float = 0.2 + max_new_tokens: int = 256 + min_new_tokens: int = 32 + top_k: float = 50 + top_p: float = 0.95 + pad_token_id: int = 49152 + repetition_penalty: float = 1.2 + START_STARCODER_COMMAND_DOCSTRING = """\ Run a LLMServer for StarCoder model. diff --git a/src/openllm/models/starcoder/modeling_starcoder.py b/src/openllm/models/starcoder/modeling_starcoder.py index 10677c72..b9a2d894 100644 --- a/src/openllm/models/starcoder/modeling_starcoder.py +++ b/src/openllm/models/starcoder/modeling_starcoder.py @@ -18,47 +18,61 @@ import bentoml import openllm from ...utils import generate_labels if t.TYPE_CHECKING: - import torch, transformers + import torch, transformers else: - torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") + torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = "", "", "", "", "<|endoftext|>", "" + class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]): - __openllm_internal__ = True - @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"} - def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto") - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1]) - tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD}) - model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs) - try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) - finally: torch.cuda.empty_cache() - def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None - if fim_mode: - try: prefix, suffix = prompt.split(FIM_INDICATOR) - except Exception as err: raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err - prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" - else: prompt_text = prompt - # XXX: This value for pad_token_id is currently a hack, need more investigate why the - # default starcoder doesn't include the same value as santacoder EOD - return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - with torch.inference_mode(): - # eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder - # NOTE: support fine-tuning starcoder - result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), do_sample=True, pad_token_id=self.tokenizer.eos_token_id, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) - # TODO: We will probably want to return the tokenizer here so that we can manually process this - # return (skip_special_tokens=False, clean_up_tokenization_spaces=False)) - return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) - def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: - max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) - src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([])) - stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) - result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) - # Inference API returns the stop sequence - for stop_seq in stop: - if result.endswith(stop_seq): result = result[: -len(stop_seq)] - return [{"generated_text": result}] + __openllm_internal__ = True + + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: + return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"} + + def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: + torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto") + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1]) + tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD}) + model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs) + try: + return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)) + finally: + torch.cuda.empty_cache() + + def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None + if fim_mode: + try: + prefix, suffix = prompt.split(FIM_INDICATOR) + except Exception as err: + raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err + prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" + else: + prompt_text = prompt + # XXX: This value for pad_token_id is currently a hack, need more investigate why the + # default starcoder doesn't include the same value as santacoder EOD + return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {} + + def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: + return generation_result[0] + + def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + with torch.inference_mode(): + # eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder + # NOTE: support fine-tuning starcoder + result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), do_sample=True, pad_token_id=self.tokenizer.eos_token_id, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) + # TODO: We will probably want to return the tokenizer here so that we can manually process this + # return (skip_special_tokens=False, clean_up_tokenization_spaces=False)) + return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) + + def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: + max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) + src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([])) + stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) + result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) + # Inference API returns the stop sequence + for stop_seq in stop: + if result.endswith(stop_seq): result = result[:-len(stop_seq)] + return [{"generated_text": result}] diff --git a/src/openllm/playground/falcon_tuned.py b/src/openllm/playground/falcon_tuned.py index 3d423afd..ecf743a6 100644 --- a/src/openllm/playground/falcon_tuned.py +++ b/src/openllm/playground/falcon_tuned.py @@ -11,7 +11,6 @@ import torch import openllm import transformers - # Make sure to have at least one GPU to run this script openllm.utils.configure_logging() @@ -21,90 +20,54 @@ logger = logging.getLogger(__name__) # On notebook, make sure to install the following # ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git - from datasets import load_dataset from trl import SFTTrainer - DEFAULT_MODEL_ID = "ybelkada/falcon-7b-sharded-bf16" DATASET_NAME = "timdettmers/openassistant-guanaco" - @dataclasses.dataclass class TrainingArguments: - per_device_train_batch_size: int = dataclasses.field(default=4) - gradient_accumulation_steps: int = dataclasses.field(default=4) - optim: str = dataclasses.field(default="paged_adamw_32bit") - save_steps: int = dataclasses.field(default=10) - warmup_steps: int = dataclasses.field(default=10) - max_steps: int = dataclasses.field(default=500) - logging_steps: int = dataclasses.field(default=10) - learning_rate: float = dataclasses.field(default=2e-4) - max_grad_norm: float = dataclasses.field(default=0.3) - warmup_ratio: float = dataclasses.field(default=0.03) - fp16: bool = dataclasses.field(default=True) - group_by_length: bool = dataclasses.field(default=True) - lr_scheduler_type: str = dataclasses.field(default="constant") - output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon")) - + per_device_train_batch_size: int = dataclasses.field(default=4) + gradient_accumulation_steps: int = dataclasses.field(default=4) + optim: str = dataclasses.field(default="paged_adamw_32bit") + save_steps: int = dataclasses.field(default=10) + warmup_steps: int = dataclasses.field(default=10) + max_steps: int = dataclasses.field(default=500) + logging_steps: int = dataclasses.field(default=10) + learning_rate: float = dataclasses.field(default=2e-4) + max_grad_norm: float = dataclasses.field(default=0.3) + warmup_ratio: float = dataclasses.field(default=0.03) + fp16: bool = dataclasses.field(default=True) + group_by_length: bool = dataclasses.field(default=True) + lr_scheduler_type: str = dataclasses.field(default="constant") + output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon")) @dataclasses.dataclass class ModelArguments: - model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID) - max_sequence_length: int = dataclasses.field(default=512) - + model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID) + max_sequence_length: int = dataclasses.field(default=512) parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: - model_args, training_args = t.cast( - t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses() - ) + model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()) -model, tokenizer = openllm.AutoLLM.for_model( - "falcon", - model_id=model_args.model_id, - quantize="int4", - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.float16, - ensure_available=True, -).prepare_for_training( - adapter_type="lora", - lora_alpha=16, - lora_dropout=0.1, - r=16, - bias="none", - target_modules=[ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - ], -) +model, tokenizer = openllm.AutoLLM.for_model("falcon", model_id=model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, ensure_available=True,).prepare_for_training(adapter_type="lora", lora_alpha=16, lora_dropout=0.1, r=16, bias="none", target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",],) model.config.use_cache = False tokenizer.pad_token = tokenizer.eos_token dataset = load_dataset(DATASET_NAME, split="train") -trainer = SFTTrainer( - model=model, - train_dataset=dataset, - dataset_text_field="text", - max_seq_length=model_args.max_sequence_length, - tokenizer=tokenizer, - args=dataclasses.replace( - transformers.TrainingArguments(training_args.output_dir), - **dataclasses.asdict(training_args), - ), -) +trainer = SFTTrainer(model=model, train_dataset=dataset, dataset_text_field="text", max_seq_length=model_args.max_sequence_length, tokenizer=tokenizer, args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args),),) # upcast layernorm in float32 for more stable training for name, module in trainer.model.named_modules(): - if "norm" in name: - module = module.to(torch.float32) + if "norm" in name: + module = module.to(torch.float32) trainer.train() diff --git a/src/openllm/playground/features.py b/src/openllm/playground/features.py index c47075e0..d1f8d1e6 100644 --- a/src/openllm/playground/features.py +++ b/src/openllm/playground/features.py @@ -5,7 +5,6 @@ import typing as t import openllm - openllm.utils.configure_logging() logger = logging.getLogger(__name__) @@ -15,45 +14,42 @@ MAX_NEW_TOKENS = 384 Q = "Answer the following question, step by step:\n{q}\nA:" question = "What is the meaning of life?" - def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("question", default=question) + parser = argparse.ArgumentParser() + parser.add_argument("question", default=question) - if openllm.utils.in_notebook(): - args = parser.parse_args(args=[question]) - else: - args = parser.parse_args() + if openllm.utils.in_notebook(): + args = parser.parse_args(args=[question]) + else: + args = parser.parse_args() - model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True) - prompt = Q.format(q=args.question) + model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True) + prompt = Q.format(q=args.question) - logger.info("-" * 50, "Running with 'generate()'", "-" * 50) - res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS) - logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res)) + logger.info("-" * 50, "Running with 'generate()'", "-" * 50) + res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS) + logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res)) - logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50) - res = model.generate(prompt, num_return_sequences=3) - logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res)) + logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50) + res = model.generate(prompt, num_return_sequences=3) + logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res)) - logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50) - r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True) - res = r.generate.run(prompt) - logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res)) + logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50) + r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True) + res = r.generate.run(prompt) + logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res)) - logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50) - res = r(prompt) - logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res)) - - return 0 + logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50) + res = r(prompt) + logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res)) + return 0 def _mp_fn(index: t.Any): # noqa # type: ignore - # For xla_spawn (TPUs) - main() - + # For xla_spawn (TPUs) + main() if openllm.utils.in_notebook(): - main() + main() else: - raise SystemExit(main()) + raise SystemExit(main()) diff --git a/src/openllm/playground/llama2_qlora.py b/src/openllm/playground/llama2_qlora.py index 5f51c9a9..baae543c 100644 --- a/src/openllm/playground/llama2_qlora.py +++ b/src/openllm/playground/llama2_qlora.py @@ -11,7 +11,7 @@ import torch import transformers if t.TYPE_CHECKING: - import peft + import peft # Make sure to have at least one GPU to run this script @@ -29,19 +29,17 @@ from itertools import chain from functools import partial from random import randrange - # COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py def find_all_linear_names(model): - lora_module_names = set() - for name, module in model.named_modules(): - if isinstance(module, bnb.nn.Linear4bit): - names = name.split(".") - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if "lm_head" in lora_module_names: # needed for 16-bit - lora_module_names.remove("lm_head") - return list(lora_module_names) + lora_module_names = set() + for name, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) # Change this to the local converted path if you don't have access to the meta-llama model DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf" @@ -49,202 +47,149 @@ DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf" DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135" DATASET_NAME = "databricks/databricks-dolly-15k" - def format_dolly(sample): - instruction = f"### Instruction\n{sample['instruction']}" - context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None - response = f"### Answer\n{sample['response']}" - # join all the parts together - prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) - return prompt - + instruction = f"### Instruction\n{sample['instruction']}" + context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None + response = f"### Answer\n{sample['response']}" + # join all the parts together + prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) + return prompt # template dataset to add prompt to each sample def template_dataset(sample, tokenizer): - sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}" - return sample - + sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}" + return sample # empty list to save remainder from batches to use in next batch remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []} - def chunk(sample, chunk_length=2048): - # define global remainder variable to save remainder from batches to use in next batch - global remainder - # Concatenate all texts and add remainder from previous batch - concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()} - concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()} - # get total number of tokens for batch - batch_total_length = len(concatenated_examples[list(sample.keys())[0]]) + # define global remainder variable to save remainder from batches to use in next batch + global remainder + # Concatenate all texts and add remainder from previous batch + concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()} + concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()} + # get total number of tokens for batch + batch_total_length = len(concatenated_examples[list(sample.keys())[0]]) - # get max number of chunks for batch - if batch_total_length >= chunk_length: - batch_chunk_length = (batch_total_length // chunk_length) * chunk_length - - # Split by chunks of max_len. - result = { - k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] - for k, t in concatenated_examples.items() - } - # add remainder to global variable for next batch - remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()} - # prepare labels - result["labels"] = result["input_ids"].copy() - return result + # get max number of chunks for batch + if batch_total_length >= chunk_length: + batch_chunk_length = (batch_total_length//chunk_length) * chunk_length + # Split by chunks of max_len. + result = {k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] for k, t in concatenated_examples.items()} + # add remainder to global variable for next batch + remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()} + # prepare labels + result["labels"] = result["input_ids"].copy() + return result def prepare_datasets(tokenizer, dataset_name=DATASET_NAME): - # Load dataset from the hub - dataset = load_dataset(dataset_name, split="train") + # Load dataset from the hub + dataset = load_dataset(dataset_name, split="train") - print(f"dataset size: {len(dataset)}") - print(dataset[randrange(len(dataset))]) + print(f"dataset size: {len(dataset)}") + print(dataset[randrange(len(dataset))]) - # apply prompt template per sample - dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features)) - # print random sample - print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"]) + # apply prompt template per sample + dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features)) + # print random sample + print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"]) - # tokenize and chunk dataset - lm_dataset = dataset.map( - lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features) - ).map( - partial(chunk, chunk_length=2048), - batched=True, - ) - - # Print total number of samples - print(f"Total number of samples: {len(lm_dataset)}") - return lm_dataset + # tokenize and chunk dataset + lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True,) + # Print total number of samples + print(f"Total number of samples: {len(lm_dataset)}") + return lm_dataset @openllm.utils.requires_dependencies("peft", extra="fine-tune") -def prepare_for_int4_training( - model_id: str, - model_version: str | None = None, - gradient_checkpointing: bool = True, - bf16: bool = True, -) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]: - from peft.tuners.lora import LoraLayer +def prepare_for_int4_training(model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True,) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]: + from peft.tuners.lora import LoraLayer - llm = openllm.AutoLLM.for_model( - "llama", - model_id=model_id, - model_version=model_version, - ensure_available=True, - quantize="int4", - bnb_4bit_compute_dtype=torch.bfloat16, - use_cache=not gradient_checkpointing, - device_map="auto", - ) - print("Model summary:", llm.model) + llm = openllm.AutoLLM.for_model("llama", model_id=model_id, model_version=model_version, ensure_available=True, quantize="int4", bnb_4bit_compute_dtype=torch.bfloat16, use_cache=not gradient_checkpointing, device_map="auto",) + print("Model summary:", llm.model) - # get lora target modules - modules = find_all_linear_names(llm.model) - print(f"Found {len(modules)} modules to quantize: {modules}") + # get lora target modules + modules = find_all_linear_names(llm.model) + print(f"Found {len(modules)} modules to quantize: {modules}") - model, tokenizer = llm.prepare_for_training( - adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules - ) - - # pre-process the model by upcasting the layer norms in float 32 for - for name, module in model.named_modules(): - if isinstance(module, LoraLayer): - if bf16: - module = module.to(torch.bfloat16) - if "norm" in name: - module = module.to(torch.float32) - if "lm_head" in name or "embed_tokens" in name: - if hasattr(module, "weight"): - if bf16 and module.weight.dtype == torch.float32: - module = module.to(torch.bfloat16) - return model, tokenizer + model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules) + # pre-process the model by upcasting the layer norms in float 32 for + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if bf16: + module = module.to(torch.bfloat16) + if "norm" in name: + module = module.to(torch.float32) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + if bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + return model, tokenizer @dataclasses.dataclass class TrainingArguments: - per_device_train_batch_size: int = dataclasses.field(default=1) - gradient_checkpointing: bool = dataclasses.field(default=True) - bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8) - learning_rate: float = dataclasses.field(default=5e-5) - num_train_epochs: int = dataclasses.field(default=3) - logging_steps: int = dataclasses.field(default=1) - report_to: str = dataclasses.field(default="none") - output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama")) - save_strategy: str = dataclasses.field(default="no") - + per_device_train_batch_size: int = dataclasses.field(default=1) + gradient_checkpointing: bool = dataclasses.field(default=True) + bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8) + learning_rate: float = dataclasses.field(default=5e-5) + num_train_epochs: int = dataclasses.field(default=3) + logging_steps: int = dataclasses.field(default=1) + report_to: str = dataclasses.field(default="none") + output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama")) + save_strategy: str = dataclasses.field(default="no") @dataclasses.dataclass class ModelArguments: - model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID) - model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION) - seed: int = dataclasses.field(default=42) - merge_weights: bool = dataclasses.field(default=False) - + model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID) + model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION) + seed: int = dataclasses.field(default=42) + merge_weights: bool = dataclasses.field(default=False) if openllm.utils.in_notebook(): - model_args, training_rags = ModelArguments(), TrainingArguments() + model_args, training_rags = ModelArguments(), TrainingArguments() else: - parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, training_args = t.cast( - t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses() - ) - + parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()) # import the model first hand openllm.import_model("llama", model_id=model_args.model_id, model_version=model_args.model_version) - def train_loop(model_args: ModelArguments, training_args: TrainingArguments): - import peft + import peft - transformers.set_seed(model_args.seed) + transformers.set_seed(model_args.seed) - model, tokenizer = prepare_for_int4_training( - model_args.model_id, - gradient_checkpointing=training_args.gradient_checkpointing, - bf16=training_args.bf16, - ) - datasets = prepare_datasets(tokenizer) + model, tokenizer = prepare_for_int4_training(model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16,) + datasets = prepare_datasets(tokenizer) - trainer = transformers.Trainer( - model=model, - args=dataclasses.replace( - transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args) - ), - train_dataset=datasets, - data_collator=transformers.default_data_collator, - ) + trainer = transformers.Trainer(model=model, args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)), train_dataset=datasets, data_collator=transformers.default_data_collator,) - trainer.train() + trainer.train() - if model_args.merge_weights: - # note that this will requires larger GPU as we will load the whole model into memory + if model_args.merge_weights: + # note that this will requires larger GPU as we will load the whole model into memory - # merge adapter weights with base model and save - # save int4 model - trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False) + # merge adapter weights with base model and save + # save int4 model + trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False) - # gc mem - del model, trainer - torch.cuda.empty_cache() - - model = peft.AutoPeftModelForCausalLM.from_pretrained( - training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16 - ) - # merge lora with base weights and save - model = model.merge_and_unload() - model.save_pretrained( - os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB" - ) - else: - trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora")) + # gc mem + del model, trainer + torch.cuda.empty_cache() + model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16) + # merge lora with base weights and save + model = model.merge_and_unload() + model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB") + else: + trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora")) train_loop(model_args, training_args) diff --git a/src/openllm/playground/opt_tuned.py b/src/openllm/playground/opt_tuned.py index 1f00e5ef..6488524f 100644 --- a/src/openllm/playground/opt_tuned.py +++ b/src/openllm/playground/opt_tuned.py @@ -20,71 +20,38 @@ logger = logging.getLogger(__name__) from datasets import load_dataset - if t.TYPE_CHECKING: - from peft import PeftModel + from peft import PeftModel DEFAULT_MODEL_ID = "facebook/opt-6.7b" - -def load_trainer( - model: PeftModel, - tokenizer: transformers.GPT2TokenizerFast, - dataset_dict: t.Any, - training_args: TrainingArguments, -): - return transformers.Trainer( - model=model, - train_dataset=dataset_dict["train"], - args=dataclasses.replace( - transformers.TrainingArguments(training_args.output_dir), - **dataclasses.asdict(training_args), - ), - data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), - ) - +def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments,): + return transformers.Trainer(model=model, train_dataset=dataset_dict["train"], args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args),), data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),) @dataclasses.dataclass class TrainingArguments: - per_device_train_batch_size: int = dataclasses.field(default=4) - gradient_accumulation_steps: int = dataclasses.field(default=4) - warmup_steps: int = dataclasses.field(default=10) - max_steps: int = dataclasses.field(default=50) - learning_rate: float = dataclasses.field(default=3e-4) - fp16: bool = dataclasses.field(default=True) - logging_steps: int = dataclasses.field(default=1) - output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt")) - + per_device_train_batch_size: int = dataclasses.field(default=4) + gradient_accumulation_steps: int = dataclasses.field(default=4) + warmup_steps: int = dataclasses.field(default=10) + max_steps: int = dataclasses.field(default=50) + learning_rate: float = dataclasses.field(default=3e-4) + fp16: bool = dataclasses.field(default=True) + logging_steps: int = dataclasses.field(default=1) + output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt")) @dataclasses.dataclass class ModelArguments: - model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID) - + model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID) parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: - model_args, training_args = t.cast( - t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses() - ) + model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()) - -model, tokenizer = openllm.AutoLLM.for_model( - "opt", - model_id=model_args.model_id, - quantize="int8", - ensure_available=True, -).prepare_for_training( - adapter_type="lora", - r=16, - lora_alpha=32, - target_modules=["q_proj", "v_proj"], - lora_dropout=0.05, - bias="none", -) +model, tokenizer = openllm.AutoLLM.for_model("opt", model_id=model_args.model_id, quantize="int8", ensure_available=True,).prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none",) # ft on english_quotes data = load_dataset("Abirate/english_quotes") diff --git a/src/openllm/serialisation/__init__.py b/src/openllm/serialisation/__init__.py index 0cf5e818..f6d9b2fa 100644 --- a/src/openllm/serialisation/__init__.py +++ b/src/openllm/serialisation/__init__.py @@ -49,104 +49,80 @@ from ..utils import LazyLoader from ..utils import LazyModule if t.TYPE_CHECKING: - import bentoml - import transformers + import bentoml + import transformers - from .._llm import M - from .._llm import T + from .._llm import M + from .._llm import T else: - transformers = LazyLoader("transformers", globals(), "transformers") - + transformers = LazyLoader("transformers", globals(), "transformers") def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: - if llm.runtime == "transformers": - return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs) - elif llm.runtime == "ggml": - return openllm.ggml.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs) - else: - raise ValueError(f"Unknown runtime: {llm.config['runtime']}") - + if llm.runtime == "transformers": + return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs) + elif llm.runtime == "ggml": + return openllm.ggml.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs) + else: + raise ValueError(f"Unknown runtime: {llm.config['runtime']}") def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model: - if llm.runtime == "transformers": - return openllm.transformers.get(llm, auto_import=auto_import) - elif llm.runtime == "ggml": - return openllm.ggml.get(llm, auto_import=auto_import) - else: - raise ValueError(f"Unknown runtime: {llm.config['runtime']}") - + if llm.runtime == "transformers": + return openllm.transformers.get(llm, auto_import=auto_import) + elif llm.runtime == "ggml": + return openllm.ggml.get(llm, auto_import=auto_import) + else: + raise ValueError(f"Unknown runtime: {llm.config['runtime']}") def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, **attrs: t.Any) -> None: - if llm.runtime == "transformers": - return openllm.transformers.save_pretrained(llm, save_directory, **attrs) - elif llm.runtime == "ggml": - return openllm.ggml.save_pretrained(llm, save_directory, **attrs) - else: - raise ValueError(f"Unknown runtime: {llm.config['runtime']}") - + if llm.runtime == "transformers": + return openllm.transformers.save_pretrained(llm, save_directory, **attrs) + elif llm.runtime == "ggml": + return openllm.ggml.save_pretrained(llm, save_directory, **attrs) + else: + raise ValueError(f"Unknown runtime: {llm.config['runtime']}") def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: - if llm.runtime == "transformers": - return openllm.transformers.load_model(llm, *decls, **attrs) - elif llm.runtime == "ggml": - return openllm.ggml.load_model(llm, *decls, **attrs) - else: - raise ValueError(f"Unknown runtime: {llm.config['runtime']}") - + if llm.runtime == "transformers": + return openllm.transformers.load_model(llm, *decls, **attrs) + elif llm.runtime == "ggml": + return openllm.ggml.load_model(llm, *decls, **attrs) + else: + raise ValueError(f"Unknown runtime: {llm.config['runtime']}") def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T: - """Load the tokenizer from BentoML store. + """Load the tokenizer from BentoML store. - By default, it will try to find the bentomodel whether it is in store.. - If model is not found, it will raises a ``bentoml.exceptions.NotFound``. - """ - from .transformers import infer_tokenizers_class_for_llm + By default, it will try to find the bentomodel whether it is in store.. + If model is not found, it will raises a ``bentoml.exceptions.NotFound``. + """ + from .transformers import infer_tokenizers_class_for_llm - bentomodel_fs = llm._bentomodel._fs - if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME): - with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile: - try: - tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"] - except KeyError: - # This could happen if users implement their own import_model - raise OpenLLMException( - "Model does not have tokenizer. Make sure to save \ + bentomodel_fs = llm._bentomodel._fs + if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME): + with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile: + try: + tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"] + except KeyError: + # This could happen if users implement their own import_model + raise OpenLLMException("Model does not have tokenizer. Make sure to save \ the tokenizer within the model via 'custom_objects'.\ - For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))" - ) from None - else: - tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained( - bentomodel_fs.getsyspath("/"), - trust_remote_code=llm.__llm_trust_remote_code__, - **tokenizer_attrs, - ) + For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))") from None + else: + tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs,) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token - return tokenizer + return tokenizer - -_extras = { - "get": get, - "import_model": import_model, - "save_pretrained": save_pretrained, - "load_model": load_model, - "load_tokenizer": load_tokenizer, -} +_extras = {"get": get, "import_model": import_model, "save_pretrained": save_pretrained, "load_model": load_model, "load_tokenizer": load_tokenizer,} _import_structure: dict[str, list[str]] = {"ggml": [], "transformers": []} if t.TYPE_CHECKING: - from . import ggml as ggml - from . import transformers as transformers + from . import ggml as ggml + from . import transformers as transformers else: - import sys + import sys - sys.modules[__name__] = LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects=_extras, - ) + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras,) diff --git a/src/openllm/serialisation/constants.py b/src/openllm/serialisation/constants.py index d0a5fe2f..699eef53 100644 --- a/src/openllm/serialisation/constants.py +++ b/src/openllm/serialisation/constants.py @@ -17,21 +17,8 @@ from __future__ import annotations FRAMEWORK_TO_AUTOCLASS_MAPPING = { "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), # NOTE: vllm will use PyTorch implementation of transformers for serialisation - "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), - "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), - "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), + "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), } - # this logic below is synonymous to handling `from_pretrained` attrs. -HUB_ATTRS = [ - "cache_dir", - "code_revision", - "force_download", - "local_files_only", - "proxies", - "resume_download", - "revision", - "subfolder", - "use_auth_token", -] +HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token",] diff --git a/src/openllm/serialisation/ggml.py b/src/openllm/serialisation/ggml.py index ce5158fa..21c5f9a2 100644 --- a/src/openllm/serialisation/ggml.py +++ b/src/openllm/serialisation/ggml.py @@ -23,54 +23,42 @@ import bentoml from ..exceptions import OpenLLMException if t.TYPE_CHECKING: - import openllm + import openllm - from .._llm import M + from .._llm import M _conversion_strategy = {"pt": "ggml"} -def import_model( - llm: openllm.LLM[t.Any, t.Any], - *decls: t.Any, - trust_remote_code: bool = True, - **attrs: t.Any, -) -> bentoml.Model: - raise NotImplementedError("Currently work in progress.") - +def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model: + raise NotImplementedError("Currently work in progress.") def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model: - """Return an instance of ``bentoml.Model`` from given LLM instance. + """Return an instance of ``bentoml.Model`` from given LLM instance. - By default, it will try to check the model in the local store. - If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub. - - Otherwise, it will raises a ``bentoml.exceptions.NotFound``. - """ - try: - model = bentoml.models.get(llm.tag) - if model.info.module not in ("openllm.serialisation.ggml", __name__): - raise bentoml.exceptions.NotFound( - f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'." - ) - if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime: - raise OpenLLMException( - f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}." - ) - return model - except bentoml.exceptions.NotFound: - if auto_import: - return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) - raise + By default, it will try to check the model in the local store. + If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub. + Otherwise, it will raises a ``bentoml.exceptions.NotFound``. + """ + try: + model = bentoml.models.get(llm.tag) + if model.info.module not in ("openllm.serialisation.ggml", __name__): + raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.") + if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime: + raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.") + return model + except bentoml.exceptions.NotFound: + if auto_import: + return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) + raise def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: - """Load the model from BentoML store. - - By default, it will try to find check the model in the local store. - If model is not found, it will raises a ``bentoml.exceptions.NotFound``. - """ - raise NotImplementedError("Currently work in progress.") + """Load the model from BentoML store. + By default, it will try to find check the model in the local store. + If model is not found, it will raises a ``bentoml.exceptions.NotFound``. + """ + raise NotImplementedError("Currently work in progress.") def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None: - raise NotImplementedError("Currently work in progress.") + raise NotImplementedError("Currently work in progress.") diff --git a/src/openllm/serialisation/transformers.py b/src/openllm/serialisation/transformers.py index a5ff18aa..699fd8fa 100644 --- a/src/openllm/serialisation/transformers.py +++ b/src/openllm/serialisation/transformers.py @@ -38,237 +38,183 @@ from ..utils import lenient_issubclass from ..utils import normalize_attrs_to_model_tokenizer_pair if t.TYPE_CHECKING: - import auto_gptq as autogptq - import torch - import torch.cuda - import vllm + import auto_gptq as autogptq + import torch + import torch.cuda + import vllm - import openllm - import transformers as _transformers - from transformers.models.auto.auto_factory import _BaseAutoModelClass + import openllm + import transformers as _transformers + from transformers.models.auto.auto_factory import _BaseAutoModelClass - from .._llm import M - from .._llm import T - from .._types import DictStrAny + from .._llm import M + from .._llm import T + from .._types import DictStrAny else: - vllm = LazyLoader("vllm", globals(), "vllm") - autogptq = LazyLoader("autogptq", globals(), "auto_gptq") - _transformers = LazyLoader("_transformers", globals(), "transformers") - torch = LazyLoader("torch", globals(), "torch") - torch.cuda = LazyLoader("torch.cuda", globals(), "torch.cuda") + vllm = LazyLoader("vllm", globals(), "vllm") + autogptq = LazyLoader("autogptq", globals(), "auto_gptq") + _transformers = LazyLoader("_transformers", globals(), "transformers") + torch = LazyLoader("torch", globals(), "torch") + torch.cuda = LazyLoader("torch.cuda", globals(), "torch.cuda") _object_setattr = object.__setattr__ def process_transformers_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[_transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]: - """Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs.""" - config = attrs.pop("config", None) - # this logic below is synonymous to handling `from_pretrained` attrs. - hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs} - if not isinstance(config, _transformers.PretrainedConfig): - copied_attrs = copy.deepcopy(attrs) - if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype") - config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs) - return config, hub_attrs, attrs + """Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs.""" + config = attrs.pop("config", None) + # this logic below is synonymous to handling `from_pretrained` attrs. + hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs} + if not isinstance(config, _transformers.PretrainedConfig): + copied_attrs = copy.deepcopy(attrs) + if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype") + config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs) + return config, hub_attrs, attrs def infer_tokenizers_class_for_llm(__llm: openllm.LLM[t.Any, T]) -> T: - tokenizer_class = __llm.config["tokenizer_class"] - if tokenizer_class is None: tokenizer_class = "AutoTokenizer" - __cls = getattr(_transformers, tokenizer_class) - if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.") - return __cls + tokenizer_class = __llm.config["tokenizer_class"] + if tokenizer_class is None: tokenizer_class = "AutoTokenizer" + __cls = getattr(_transformers, tokenizer_class) + if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.") + return __cls def infer_autoclass_from_llm_config(llm: openllm.LLM[M, T], config: _transformers.PretrainedConfig) -> _BaseAutoModelClass: - if llm.config["trust_remote_code"]: - autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM" - if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping") - # in case this model doesn't use the correct auto class for model type, for example like chatglm - # where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel - if autoclass not in config.auto_map: autoclass = "AutoModel" - return getattr(_transformers, autoclass) - else: - if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0 - elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1 - else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.") - return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx]) + if llm.config["trust_remote_code"]: + autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM" + if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping") + # in case this model doesn't use the correct auto class for model type, for example like chatglm + # where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel + if autoclass not in config.auto_map: autoclass = "AutoModel" + return getattr(_transformers, autoclass) + else: + if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0 + elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1 + else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.") + return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx]) def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: - """Auto detect model type from given model_id and import it to bentoml's model store. + """Auto detect model type from given model_id and import it to bentoml's model store. - For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first, - returning all of the unused kwargs. - The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants). - For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion. + For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first, + returning all of the unused kwargs. + The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants). + For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion. - Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`. + Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`. - Refer to Transformers documentation for more information about kwargs. + Refer to Transformers documentation for more information about kwargs. - Args: - llm: The LLM instance for this given model. - trust_remote_code: Whether to trust the remote code when loading the model. - *decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants). - **attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants). - """ - config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs) - _, tokenizer_attrs = llm.llm_parameters - quantize_method = llm._quantize_method - safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors") - # Disable safe serialization with vLLM - if llm.__llm_implementation__ == "vllm": safe_serialisation = False - metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False} - signatures: DictStrAny = {} - if quantize_method == "gptq": - if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") - if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") - model = autogptq.AutoGPTQForCausalLM.from_quantized( - llm.model_id, - *decls, - quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), - trust_remote_code=trust_remote_code, - use_safetensors=safe_serialisation, - **hub_attrs, - **attrs, - ) - metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework}) - signatures["generate"] = {"batchable": False} - else: - # this model might be called with --quantize int4, therefore we need to pop this out - # since saving int4 is not yet supported - if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config") - model = infer_autoclass_from_llm_config(llm, config).from_pretrained( - llm.model_id, - *decls, - config=config, - trust_remote_code=trust_remote_code, - use_safetensors=safe_serialisation, - **hub_attrs, - **attrs, - ) - metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.framework}) + Args: + llm: The LLM instance for this given model. + trust_remote_code: Whether to trust the remote code when loading the model. + *decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants). + **attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants). + """ + config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs) + _, tokenizer_attrs = llm.llm_parameters + quantize_method = llm._quantize_method + safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors") + # Disable safe serialization with vLLM + if llm.__llm_implementation__ == "vllm": safe_serialisation = False + metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False} + signatures: DictStrAny = {} + if quantize_method == "gptq": + if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") + if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") + model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,) + metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework}) + signatures["generate"] = {"batchable": False} + else: + # this model might be called with --quantize int4, therefore we need to pop this out + # since saving int4 is not yet supported + if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config") + model = infer_autoclass_from_llm_config(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,) + metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.framework}) - _tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained( - llm.model_id, - trust_remote_code=trust_remote_code, - **hub_attrs, - **tokenizer_attrs, - ) - if _tokenizer.pad_token is None: _tokenizer.pad_token = _tokenizer.eos_token + _tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs,) + if _tokenizer.pad_token is None: _tokenizer.pad_token = _tokenizer.eos_token - # NOTE: quick hack to set the loaded into llm object to use with save_pretrained - # to avoid recursive call when the model is not yet available in local store - _object_setattr(llm, "__llm_model__", model) - _object_setattr(llm, "__llm_tokenizer__", _tokenizer) - create_kwargs: DictStrAny= dict( - module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), labels=generate_labels(llm), metadata=metadata, - signatures=signatures if signatures else make_default_signatures(model), options=ModelOptions(), - external_modules=[importlib.import_module(model.__module__), importlib.import_module(_tokenizer.__module__)] if trust_remote_code else None, - ) - if "use_tempfs" in inspect.signature(bentoml.models.create).parameters: create_kwargs["use_tempfs"] = False + # NOTE: quick hack to set the loaded into llm object to use with save_pretrained + # to avoid recursive call when the model is not yet available in local store + _object_setattr(llm, "__llm_model__", model) + _object_setattr(llm, "__llm_tokenizer__", _tokenizer) + create_kwargs: DictStrAny = dict( + module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), labels=generate_labels(llm), metadata=metadata, signatures=signatures if signatures else make_default_signatures(model), options=ModelOptions(), external_modules=[importlib.import_module(model.__module__), + importlib.import_module(_tokenizer.__module__)] if trust_remote_code else None, + ) + if "use_tempfs" in inspect.signature(bentoml.models.create).parameters: create_kwargs["use_tempfs"] = False - try: - with bentoml.models.create(llm.tag, **create_kwargs) as bentomodel: - save_pretrained(llm, bentomodel.path, safe_serialization=safe_serialisation) - return bentomodel - finally: - # NOTE: We need to free up the cache after importing the model - # in the case where users first run openllm start without the model - # available locally. - if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() + try: + with bentoml.models.create(llm.tag, **create_kwargs) as bentomodel: + save_pretrained(llm, bentomodel.path, safe_serialization=safe_serialisation) + return bentomodel + finally: + # NOTE: We need to free up the cache after importing the model + # in the case where users first run openllm start without the model + # available locally. + if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model: - """Return an instance of ``bentoml.Model`` from given LLM instance. + """Return an instance of ``bentoml.Model`` from given LLM instance. - By default, it will try to check the model in the local store. - If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub. + By default, it will try to check the model in the local store. + If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub. - Otherwise, it will raises a ``bentoml.exceptions.NotFound``. - """ - try: - model = bentoml.models.get(llm.tag) - # compat with bentoml.transformers.get - if model.info.module not in ("openllm.serialisation.transformers", __name__): - raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.") - if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime: - raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.") - return model - except bentoml.exceptions.NotFound as err: - if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) - raise err from None + Otherwise, it will raises a ``bentoml.exceptions.NotFound``. + """ + try: + model = bentoml.models.get(llm.tag) + # compat with bentoml.transformers.get + if model.info.module not in ("openllm.serialisation.transformers", __name__): + raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.") + if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime: + raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.") + return model + except bentoml.exceptions.NotFound as err: + if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) + raise err from None def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: - """Load the model from BentoML store. + """Load the model from BentoML store. - By default, it will try to find check the model in the local store. - If model is not found, it will raises a ``bentoml.exceptions.NotFound``. - """ - config, hub_attrs, attrs = process_transformers_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs) - safe_serialization = first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors") - if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq": - if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") - if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") - return autogptq.AutoGPTQForCausalLM.from_quantized( - llm._bentomodel.path, - *decls, - quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), - trust_remote_code=llm.__llm_trust_remote_code__, - use_safetensors=safe_serialization, - **hub_attrs, - **attrs, - ) + By default, it will try to find check the model in the local store. + If model is not found, it will raises a ``bentoml.exceptions.NotFound``. + """ + config, hub_attrs, attrs = process_transformers_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs) + safe_serialization = first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors") + if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq": + if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") + if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") + return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=llm.__llm_trust_remote_code__, use_safetensors=safe_serialization, **hub_attrs, **attrs,) - model = infer_autoclass_from_llm_config(llm, config).from_pretrained( - llm._bentomodel.path, - *decls, - config=config, - trust_remote_code=llm.__llm_trust_remote_code__, - **hub_attrs, - **attrs, - ) - # NOTE: we only cast and load the model if it is not already quantized and setup correctly - loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_quantized", False) - if torch.cuda.is_available() and device_count() == 1 and not loaded_in_kbit: - try: model = model.to("cuda") - except torch.cuda.OutOfMemoryError as err: raise RuntimeError(f"Failed to convert {llm.config['model_name']} with model_id '{llm.model_id}' to CUDA.\nNote: You can try out '--quantize int8 | int4' for dynamic quantization.") from err - # BetterTransformer is currently only supported on PyTorch. - if llm.bettertransformer and isinstance(model, _transformers.PreTrainedModel): model = model.to_bettertransformer() - return t.cast("M", model) + model = infer_autoclass_from_llm_config(llm, config).from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, **hub_attrs, **attrs,) + # NOTE: we only cast and load the model if it is not already quantized and setup correctly + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_quantized", False) + if torch.cuda.is_available() and device_count() == 1 and not loaded_in_kbit: + try: + model = model.to("cuda") + except torch.cuda.OutOfMemoryError as err: + raise RuntimeError(f"Failed to convert {llm.config['model_name']} with model_id '{llm.model_id}' to CUDA.\nNote: You can try out '--quantize int8 | int4' for dynamic quantization.") from err + # BetterTransformer is currently only supported on PyTorch. + if llm.bettertransformer and isinstance(model, _transformers.PreTrainedModel): model = model.to_bettertransformer() + return t.cast("M", model) -def save_pretrained( - llm: openllm.LLM[M, T], - save_directory: str, - is_main_process: bool = True, - state_dict: DictStrAny | None = None, - save_function: t.Callable[..., None] | None = None, - push_to_hub: bool = False, - max_shard_size: int | str = "10GB", - safe_serialization: bool = False, - variant: str | None = None, - **attrs: t.Any, -) -> None: - """Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``.""" - save_function = first_not_none(save_function, default=torch.save) - model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs) - safe_serialization = safe_serialization or llm._serialisation_format == "safetensors" - # NOTE: disable safetensors for vllm - if llm.__llm_implementation__ == "vllm": safe_serialization = False - if llm._quantize_method == "gptq": - if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") - if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") - if not lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})") - t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization) - elif LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model): raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.") - elif isinstance(llm.model, _transformers.Pipeline): llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization) - else: - # We can safely cast here since it will be the PreTrainedModel protocol. - t.cast("_transformers.PreTrainedModel", llm.model).save_pretrained( - save_directory, - is_main_process=is_main_process, - state_dict=state_dict, - save_function=save_function, - push_to_hub=push_to_hub, - max_shard_size=max_shard_size, - safe_serialization=safe_serialization, - variant=variant, - **model_save_attrs - ) - llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs) +def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, is_main_process: bool = True, state_dict: DictStrAny | None = None, save_function: t.Callable[..., None] | None = None, push_to_hub: bool = False, max_shard_size: int | str = "10GB", safe_serialization: bool = False, variant: str | None = None, **attrs: t.Any,) -> None: + """Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``.""" + save_function = first_not_none(save_function, default=torch.save) + model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs) + safe_serialization = safe_serialization or llm._serialisation_format == "safetensors" + # NOTE: disable safetensors for vllm + if llm.__llm_implementation__ == "vllm": safe_serialization = False + if llm._quantize_method == "gptq": + if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") + if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") + if not lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})") + t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization) + elif LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model): + raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.") + elif isinstance(llm.model, _transformers.Pipeline): + llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization) + else: + # We can safely cast here since it will be the PreTrainedModel protocol. + t.cast("_transformers.PreTrainedModel", llm.model).save_pretrained(save_directory, is_main_process=is_main_process, state_dict=state_dict, save_function=save_function, push_to_hub=push_to_hub, max_shard_size=max_shard_size, safe_serialization=safe_serialization, variant=variant, **model_save_attrs) + llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs) diff --git a/src/openllm/testing.py b/src/openllm/testing.py index 50b429cd..9035bca0 100644 --- a/src/openllm/testing.py +++ b/src/openllm/testing.py @@ -26,77 +26,50 @@ import openllm logger = logging.getLogger(__name__) if t.TYPE_CHECKING: - from ._types import LiteralRuntime - + from ._types import LiteralRuntime @contextlib.contextmanager -def build_bento( - model: str, - model_id: str | None = None, - quantize: t.Literal["int4", "int8", "gptq"] | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers", - cleanup: bool = False, -) -> t.Iterator[bentoml.Bento]: - logger.info("Building BentoML for %s", model) - bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime) - yield bento +def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False,) -> t.Iterator[bentoml.Bento]: + logger.info("Building BentoML for %s", model) + bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime) + yield bento + if cleanup: + logger.info("Deleting %s", bento.tag) + bentoml.bentos.delete(bento.tag) + +@contextlib.contextmanager +def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any,) -> t.Iterator[str]: + if isinstance(bento, bentoml.Bento): bento_tag = bento.tag + else: bento_tag = bentoml.Tag.from_taglike(bento) + if image_tag is None: image_tag = str(bento_tag) + + executable = shutil.which("docker") + if not executable: + raise RuntimeError("docker executable not found") + + try: + logger.info("Building container for %s", bento_tag) + bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs,) + yield image_tag + finally: if cleanup: - logger.info("Deleting %s", bento.tag) - bentoml.bentos.delete(bento.tag) - + logger.info("Deleting container %s", image_tag) + subprocess.check_output([executable, "rmi", "-f", image_tag]) @contextlib.contextmanager -def build_container( - bento: bentoml.Bento | str | bentoml.Tag, - image_tag: str | None = None, - cleanup: bool = False, - **attrs: t.Any, -) -> t.Iterator[str]: - if isinstance(bento, bentoml.Bento): bento_tag = bento.tag - else: bento_tag = bentoml.Tag.from_taglike(bento) - if image_tag is None: image_tag = str(bento_tag) +def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True,) -> t.Iterator[str]: + if clean_context is None: + clean_context = contextlib.ExitStack() + cleanup = True - executable = shutil.which("docker") - if not executable: - raise RuntimeError("docker executable not found") + llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True) + bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}") - try: - logger.info("Building container for %s", bento_tag) - bentoml.container.build( - bento_tag, - backend="docker", - image_tag=(image_tag,), - progress="plain", - **attrs, - ) - yield image_tag - finally: - if cleanup: - logger.info("Deleting container %s", image_tag) - subprocess.check_output([executable, "rmi", "-f", image_tag]) + if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup)) + else: bento = bentoml.get(bento_tag) + container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_") -@contextlib.contextmanager -def prepare( - model: str, - model_id: str | None = None, - implementation: LiteralRuntime = "pt", - deployment_mode: t.Literal["container", "local"] = "local", - clean_context: contextlib.ExitStack | None = None, - cleanup: bool = True, -) -> t.Iterator[str]: - if clean_context is None: - clean_context = contextlib.ExitStack() - cleanup = True - - llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True) - bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}") - - if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup)) - else: bento = bentoml.get(bento_tag) - - container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_") - - if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup)) - yield container_name - if cleanup: clean_context.close() + if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup)) + yield container_name + if cleanup: clean_context.close() diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index 73703460..465905b3 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -50,74 +50,77 @@ from .lazy import LazyModule logger = logging.getLogger(__name__) try: - from typing import GenericAlias as _TypingGenericAlias # type: ignore + from typing import GenericAlias as _TypingGenericAlias # type: ignore except ImportError: - # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on) - _TypingGenericAlias = () # type: ignore + # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on) + _TypingGenericAlias = () # type: ignore if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,) else: - # _GenericAlias is the actual GenericAlias implementation - _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore + # _GenericAlias is the actual GenericAlias implementation + _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import overload as _overload + from typing import overload as _overload else: - from typing_extensions import overload as _overload + from typing_extensions import overload as _overload if t.TYPE_CHECKING: - import openllm + import openllm - from .._types import AnyCallable - from .._types import DictStrAny - from .._types import LiteralRuntime + from .._types import AnyCallable + from .._types import DictStrAny + from .._types import LiteralRuntime DEV_DEBUG_VAR = "OPENLLMDEVDEBUG" def set_debug_mode(enabled: bool, level: int = 1) -> None: - # monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs - if enabled: os.environ[DEV_DEBUG_VAR] = str(level) - os.environ[DEBUG_ENV_VAR] = str(enabled) - os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR" + # monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs + if enabled: os.environ[DEV_DEBUG_VAR] = str(level) + os.environ[DEBUG_ENV_VAR] = str(enabled) + os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR" def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool: - try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type] - except TypeError: - if isinstance(cls, _WithArgsTypes): return False - raise + try: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type] + except TypeError: + if isinstance(cls, _WithArgsTypes): return False + raise def available_devices() -> tuple[str, ...]: - """Return available GPU under system. Currently only supports NVIDIA GPUs.""" - from .._strategies import NvidiaGpuResource - return tuple(NvidiaGpuResource.from_system()) + """Return available GPU under system. Currently only supports NVIDIA GPUs.""" + from .._strategies import NvidiaGpuResource + return tuple(NvidiaGpuResource.from_system()) @functools.lru_cache(maxsize=128) def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str: - """Generate a hash from given file's modification time. + """Generate a hash from given file's modification time. - Args: - f: The file to generate the hash from. - algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.) + Args: + f: The file to generate the hash from. + algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.) - Returns: - The generated hash. - """ - return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest() + Returns: + The generated hash. + """ + return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest() @functools.lru_cache(maxsize=1) -def device_count() -> int: return len(available_devices()) +def device_count() -> int: + return len(available_devices()) # equivocal setattr to save one lookup per assignment _object_setattr = object.__setattr__ def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None: - """This makes sure that we don't overwrite any existing attributes on the object.""" - _setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj) - if not hasattr(obj, name): _setattr(name, value) + """This makes sure that we don't overwrite any existing attributes on the object.""" + _setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj) + if not hasattr(obj, name): _setattr(name, value) -def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str: return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key]))) +def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str: + return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key]))) # Special debug flag controled via OPENLLMDEVDEBUG DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.getenv(DEV_DEBUG_VAR))) @@ -125,210 +128,201 @@ DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.gete MYPY = False SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3 -def get_debug_mode() -> bool: return DEBUG or _get_debug_mode() -def get_quiet_mode() -> bool: return not DEBUG and _get_quiet_mode() +def get_debug_mode() -> bool: + return DEBUG or _get_debug_mode() + +def get_quiet_mode() -> bool: + return not DEBUG and _get_quiet_mode() class ExceptionFilter(logging.Filter): - def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any): - """A filter of all exception.""" - if exclude_exceptions is None: exclude_exceptions = [ConflictError] - if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError) - super(ExceptionFilter, self).__init__(**kwargs) - self.EXCLUDE_EXCEPTIONS = exclude_exceptions - def filter(self, record: logging.LogRecord) -> bool: - if record.exc_info: - etype, _, _ = record.exc_info - if etype is not None: - for exc in self.EXCLUDE_EXCEPTIONS: - if issubclass(etype, exc): return False - return True + def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any): + """A filter of all exception.""" + if exclude_exceptions is None: exclude_exceptions = [ConflictError] + if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError) + super(ExceptionFilter, self).__init__(**kwargs) + self.EXCLUDE_EXCEPTIONS = exclude_exceptions + + def filter(self, record: logging.LogRecord) -> bool: + if record.exc_info: + etype, _, _ = record.exc_info + if etype is not None: + for exc in self.EXCLUDE_EXCEPTIONS: + if issubclass(etype, exc): return False + return True class InfoFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING + def filter(self, record: logging.LogRecord) -> bool: + return logging.INFO <= record.levelno < logging.WARNING _LOGGING_CONFIG: DictStrAny = { - "version": 1, - "disable_existing_loggers": True, - "filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}}, - "handlers": { - "bentomlhandler": { - "class": "logging.StreamHandler", - "filters": ["excfilter", "infofilter"], - "stream": "ext://sys.stdout", - }, - "defaulthandler": { - "class": "logging.StreamHandler", - "level": logging.WARNING, - }, - }, - "loggers": { - "bentoml": { - "handlers": ["bentomlhandler", "defaulthandler"], - "level": logging.INFO, - "propagate": False, - }, - "openllm": { - "handlers": ["bentomlhandler", "defaulthandler"], - "level": logging.INFO, - "propagate": False, - }, - }, - "root": {"level": logging.WARNING}, + "version": 1, "disable_existing_loggers": True, "filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}}, "handlers": { + "bentomlhandler": {"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout",}, "defaulthandler": {"class": "logging.StreamHandler", "level": logging.WARNING,}, + }, "loggers": {"bentoml": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,}, "openllm": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,},}, "root": {"level": logging.WARNING}, } def configure_logging() -> None: - """Configure logging for OpenLLM. + """Configure logging for OpenLLM. - Behaves similar to how BentoML loggers are being configured. - """ - if get_quiet_mode(): - _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR - _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR - _LOGGING_CONFIG["root"]["level"] = logging.ERROR - elif get_debug_mode() or DEBUG: - _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG - _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG - _LOGGING_CONFIG["root"]["level"] = logging.DEBUG - else: - _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO - _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO - _LOGGING_CONFIG["root"]["level"] = logging.INFO + Behaves similar to how BentoML loggers are being configured. + """ + if get_quiet_mode(): + _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR + _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR + _LOGGING_CONFIG["root"]["level"] = logging.ERROR + elif get_debug_mode() or DEBUG: + _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG + _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG + _LOGGING_CONFIG["root"]["level"] = logging.DEBUG + else: + _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO + _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO + _LOGGING_CONFIG["root"]["level"] = logging.INFO - logging.config.dictConfig(_LOGGING_CONFIG) + logging.config.dictConfig(_LOGGING_CONFIG) @functools.lru_cache(maxsize=1) def in_notebook() -> bool: - try: - from IPython.core.getipython import get_ipython - if "IPKernelApp" not in get_ipython().config: return False - except ImportError: return False - except AttributeError: return False - return True + try: + from IPython.core.getipython import get_ipython + if "IPKernelApp" not in get_ipython().config: return False + except ImportError: + return False + except AttributeError: + return False + return True _dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup") class suppress(contextlib.suppress, contextlib.ContextDecorator): - """A version of contextlib.suppress with decorator support. + """A version of contextlib.suppress with decorator support. - >>> @suppress(KeyError) - ... def key_error(): - ... {}[''] - >>> key_error() - """ + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ def compose(*funcs: AnyCallable) -> AnyCallable: - """Compose any number of unary functions into a single unary function. + """Compose any number of unary functions into a single unary function. - >>> import textwrap - >>> expected = str.strip(textwrap.dedent(compose.__doc__)) - >>> strip_and_dedent = compose(str.strip, textwrap.dedent) - >>> strip_and_dedent(compose.__doc__) == expected - True + >>> import textwrap + >>> expected = str.strip(textwrap.dedent(compose.__doc__)) + >>> strip_and_dedent = compose(str.strip, textwrap.dedent) + >>> strip_and_dedent(compose.__doc__) == expected + True - Compose also allows the innermost function to take arbitrary arguments. + Compose also allows the innermost function to take arbitrary arguments. - >>> round_three = lambda x: round(x, ndigits=3) - >>> f = compose(round_three, int.__truediv__) - >>> [f(3*x, x+1) for x in range(1,10)] - [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] - """ - def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs)) - return functools.reduce(compose_two, funcs) + >>> round_three = lambda x: round(x, ndigits=3) + >>> f = compose(round_three, int.__truediv__) + >>> [f(3*x, x+1) for x in range(1,10)] + [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] + """ + def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: + return lambda *args, **kwargs: f1(f2(*args, **kwargs)) + + return functools.reduce(compose_two, funcs) def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]: - """Decorate a function with a transform function that is invoked on results returned from the decorated function. + """Decorate a function with a transform function that is invoked on results returned from the decorated function. - ```python - @apply(reversed) - def get_numbers(start): - "doc for get_numbers" - return range(start, start+3) - list(get_numbers(4)) - # [6, 5, 4] - ``` - ```python - get_numbers.__doc__ - # 'doc for get_numbers' - ``` - """ - return lambda func: functools.wraps(func)(compose(transform, func)) + ```python + @apply(reversed) + def get_numbers(start): + "doc for get_numbers" + return range(start, start+3) + list(get_numbers(4)) + # [6, 5, 4] + ``` + ```python + get_numbers.__doc__ + # 'doc for get_numbers' + ``` + """ + return lambda func: functools.wraps(func)(compose(transform, func)) @apply(bool) @suppress(FileNotFoundError) -def _text_in_file(text: str, filename: Path) -> bool: return any(text in line for line in filename.open()) - +def _text_in_file(text: str, filename: Path) -> bool: + return any(text in line for line in filename.open()) def in_docker() -> bool: - """Is this current environment running in docker? + """Is this current environment running in docker? - ```python - type(in_docker()) - ``` - """ - return _dockerenv.exists() or _text_in_file("docker", _cgroup) + ```python + type(in_docker()) + ``` + """ + return _dockerenv.exists() or _text_in_file("docker", _cgroup) T, K = t.TypeVar("T"), t.TypeVar("K") def resolve_filepath(path: str, ctx: str | None = None) -> str: - """Resolve a file path to an absolute path, expand user and environment variables.""" - try: return resolve_user_filepath(path, ctx) - except FileNotFoundError: return path + """Resolve a file path to an absolute path, expand user and environment variables.""" + try: + return resolve_user_filepath(path, ctx) + except FileNotFoundError: + return path -def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) +def validate_is_path(maybe_path: str) -> bool: + return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) def generate_context(framework_name: str) -> _ModelContext: - from .import_utils import is_flax_available - from .import_utils import is_tf_available - from .import_utils import is_torch_available + from .import_utils import is_flax_available + from .import_utils import is_tf_available + from .import_utils import is_torch_available - framework_versions = {"transformers": pkg.get_pkg_version("transformers")} - if is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch") - if is_tf_available(): - from bentoml._internal.frameworks.utils.tensorflow import get_tf_version - framework_versions["tensorflow"] = get_tf_version() - if is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")}) - return _ModelContext(framework_name=framework_name, framework_versions=framework_versions) + framework_versions = {"transformers": pkg.get_pkg_version("transformers")} + if is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch") + if is_tf_available(): + from bentoml._internal.frameworks.utils.tensorflow import get_tf_version + framework_versions["tensorflow"] = get_tf_version() + if is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")}) + return _ModelContext(framework_name=framework_name, framework_versions=framework_versions) def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> DictStrAny: - return { - "runtime": llm.runtime, - "framework": "openllm", - "model_name": llm.config["model_name"], - "architecture": llm.config["architecture"], - "serialisation_format": llm._serialisation_format, - } + return {"runtime": llm.runtime, "framework": "openllm", "model_name": llm.config["model_name"], "architecture": llm.config["architecture"], "serialisation_format": llm._serialisation_format,} _TOKENIZER_PREFIX = "_tokenizer_" + def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[DictStrAny, DictStrAny]: - """Normalize the given attrs to a model and tokenizer kwargs accordingly.""" - tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)} - for k in tuple(attrs.keys()): - if k.startswith(_TOKENIZER_PREFIX): del attrs[k] - return attrs, tokenizer_attrs + """Normalize the given attrs to a model and tokenizer kwargs accordingly.""" + tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)} + for k in tuple(attrs.keys()): + if k.startswith(_TOKENIZER_PREFIX): del attrs[k] + return attrs, tokenizer_attrs @_overload -def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: ... +def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: + ... + @_overload -def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: ... +def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: + ... + @_overload -def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: ... +def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: + ... + @_overload -def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: ... +def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: + ... + def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM] | type[openllm.AutoTFLLM] | type[openllm.AutoFlaxLLM] | type[openllm.AutoVLLM]: - if implementation == "tf": - from openllm import AutoTFLLM - return AutoTFLLM - elif implementation == "flax": - from openllm import AutoFlaxLLM - return AutoFlaxLLM - elif implementation == "pt": - from openllm import AutoLLM - return AutoLLM - elif implementation == "vllm": - from openllm import AutoVLLM - return AutoVLLM - else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')") - + if implementation == "tf": + from openllm import AutoTFLLM + return AutoTFLLM + elif implementation == "flax": + from openllm import AutoFlaxLLM + return AutoFlaxLLM + elif implementation == "pt": + from openllm import AutoLLM + return AutoLLM + elif implementation == "vllm": + from openllm import AutoVLLM + return AutoVLLM + else: + raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')") # NOTE: The set marks contains a set of modules name # that are available above and are whitelisted @@ -337,80 +331,52 @@ _whitelist_modules = {"pkg"} # XXX: define all classes, functions import above this line # since _extras will be the locals() import from this file. -_extras: dict[str, t.Any] = { - k: v - for k, v in locals().items() - if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_")) -} +_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))} _extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"} _import_structure: dict[str, list[str]] = { - "analytics": [], - "codegen": [], - "dantic": [], - "representation": ["ReprMixin"], - "lazy": ["LazyModule"], - "import_utils": [ - "OPTIONAL_DEPENDENCIES", - "ENV_VARS_TRUE_VALUES", - "DummyMetaclass", - "EnvVarMixin", - "requires_dependencies", - "is_cpm_kernels_available", - "is_einops_available", - "is_flax_available", - "is_tf_available", - "is_vllm_available", - "is_torch_available", - "is_bitsandbytes_available", - "is_peft_available", - "is_datasets_available", - "is_transformers_supports_kbit", - "is_transformers_supports_agent", - "is_jupyter_available", - "is_jupytext_available", - "is_notebook_available", - "is_triton_available", - "is_autogptq_available", - "require_backends", + "analytics": [], "codegen": [], "dantic": [], "representation": ["ReprMixin"], "lazy": ["LazyModule"], "import_utils": [ + "OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "EnvVarMixin", "requires_dependencies", "is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_vllm_available", "is_torch_available", "is_bitsandbytes_available", "is_peft_available", "is_datasets_available", "is_transformers_supports_kbit", + "is_transformers_supports_agent", "is_jupyter_available", "is_jupytext_available", "is_notebook_available", "is_triton_available", "is_autogptq_available", "require_backends", ], } if t.TYPE_CHECKING: - # NOTE: The following exports useful utils from bentoml - from . import LazyLoader as LazyLoader - from . import LazyType as LazyType - from . import analytics as analytics - from . import bentoml_cattr as bentoml_cattr - from . import codegen as codegen - from . import configure_logging as configure_logging - from . import dantic as dantic - from . import first_not_none as first_not_none - from . import reserve_free_port as reserve_free_port - from . import set_quiet_mode as set_quiet_mode - from . import validate_is_path as validate_is_path - from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES - from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES - from .import_utils import DummyMetaclass as DummyMetaclass - from .import_utils import EnvVarMixin as EnvVarMixin - from .import_utils import is_autogptq_available as is_autogptq_available - from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available - from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available - from .import_utils import is_datasets_available as is_datasets_available - from .import_utils import is_einops_available as is_einops_available - from .import_utils import is_flax_available as is_flax_available - from .import_utils import is_jupyter_available as is_jupyter_available - from .import_utils import is_jupytext_available as is_jupytext_available - from .import_utils import is_notebook_available as is_notebook_available - from .import_utils import is_peft_available as is_peft_available - from .import_utils import is_tf_available as is_tf_available - from .import_utils import is_torch_available as is_torch_available - from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent - from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit - from .import_utils import is_triton_available as is_triton_available - from .import_utils import is_vllm_available as is_vllm_available - from .import_utils import require_backends as require_backends - from .import_utils import requires_dependencies as requires_dependencies - from .representation import ReprMixin as ReprMixin -else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras) + # NOTE: The following exports useful utils from bentoml + from . import LazyLoader as LazyLoader + from . import LazyType as LazyType + from . import analytics as analytics + from . import bentoml_cattr as bentoml_cattr + from . import codegen as codegen + from . import configure_logging as configure_logging + from . import dantic as dantic + from . import first_not_none as first_not_none + from . import reserve_free_port as reserve_free_port + from . import set_quiet_mode as set_quiet_mode + from . import validate_is_path as validate_is_path + from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES + from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES + from .import_utils import DummyMetaclass as DummyMetaclass + from .import_utils import EnvVarMixin as EnvVarMixin + from .import_utils import is_autogptq_available as is_autogptq_available + from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available + from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available + from .import_utils import is_datasets_available as is_datasets_available + from .import_utils import is_einops_available as is_einops_available + from .import_utils import is_flax_available as is_flax_available + from .import_utils import is_jupyter_available as is_jupyter_available + from .import_utils import is_jupytext_available as is_jupytext_available + from .import_utils import is_notebook_available as is_notebook_available + from .import_utils import is_peft_available as is_peft_available + from .import_utils import is_tf_available as is_tf_available + from .import_utils import is_torch_available as is_torch_available + from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent + from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit + from .import_utils import is_triton_available as is_triton_available + from .import_utils import is_vllm_available as is_vllm_available + from .import_utils import require_backends as require_backends + from .import_utils import requires_dependencies as requires_dependencies + from .representation import ReprMixin as ReprMixin +else: + sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras) diff --git a/src/openllm/utils/analytics.py b/src/openllm/utils/analytics.py index d85c6d32..8df0c678 100644 --- a/src/openllm/utils/analytics.py +++ b/src/openllm/utils/analytics.py @@ -30,12 +30,11 @@ import openllm from bentoml._internal.utils import analytics as _internal_analytics if t.TYPE_CHECKING: - from .._types import P - from .._types import T + from .._types import P + from .._types import T logger = logging.getLogger(__name__) - ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} # This variable is a proxy that will control BENTOML_DO_NOT_TRACK @@ -43,87 +42,78 @@ OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK" DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper() - @functools.lru_cache(maxsize=1) def do_not_track() -> bool: - return DO_NOT_TRACK in ENV_VARS_TRUE_VALUES - + return DO_NOT_TRACK in ENV_VARS_TRUE_VALUES @functools.lru_cache(maxsize=1) def _usage_event_debugging() -> bool: - # For BentoML developers only - debug and print event payload if turned on - return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true" - + # For BentoML developers only - debug and print event payload if turned on + return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true" def silent(func: t.Callable[P, T]) -> t.Callable[P, T]: - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: - try: - return func(*args, **kwargs) - except Exception as err: - if _usage_event_debugging(): - if openllm.utils.get_debug_mode(): - logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3) - else: - logger.info("Tracking Error: %s", err) - else: - logger.debug("Tracking Error: %s", err) - - return wrapper + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: + try: + return func(*args, **kwargs) + except Exception as err: + if _usage_event_debugging(): + if openllm.utils.get_debug_mode(): + logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3) + else: + logger.info("Tracking Error: %s", err) + else: + logger.debug("Tracking Error: %s", err) + return wrapper @silent def track(event_properties: attr.AttrsInstance) -> None: - if do_not_track(): - return - _internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties)) - + if do_not_track(): + return + _internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties)) @contextlib.contextmanager def set_bentoml_tracking() -> t.Generator[None, None, None]: - original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False)) - try: - os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track()) - yield - finally: - os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value - + original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False)) + try: + os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track()) + yield + finally: + os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value class EventMeta: - @property - def event_name(self) -> str: - # camel case to snake case - event_name = re.sub(r"(? str: + # camel case to snake case + event_name = re.sub(r"(? StartInitEvent: - return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump()) + model_name: str + llm_config: t.Dict[str, t.Any] = attr.field(default=None) + @staticmethod + def handler(llm_config: openllm.LLMConfig) -> StartInitEvent: + return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump()) def track_start_init(llm_config: openllm.LLMConfig) -> None: - if do_not_track(): - return - track(StartInitEvent.handler(llm_config)) + if do_not_track(): + return + track(StartInitEvent.handler(llm_config)) diff --git a/src/openllm/utils/codegen.py b/src/openllm/utils/codegen.py index 6fec8944..68b47020 100644 --- a/src/openllm/utils/codegen.py +++ b/src/openllm/utils/codegen.py @@ -26,19 +26,19 @@ from pathlib import Path import orjson if t.TYPE_CHECKING: - from fs.base import FS + from fs.base import FS - import openllm + import openllm - from .._types import AnyCallable - from .._types import DictStrAny - from .._types import ListStr + from .._types import AnyCallable + from .._types import DictStrAny + from .._types import ListStr - PartialAny = functools.partial[t.Any] + PartialAny = functools.partial[t.Any] else: - DictStrAny = dict - ListStr = list - PartialAny = functools.partial + DictStrAny = dict + ListStr = list + PartialAny = functools.partial _T = t.TypeVar("_T", bound=t.Callable[..., t.Any]) @@ -47,274 +47,206 @@ logger = logging.getLogger(__name__) OPENLLM_MODEL_NAME = "# openllm: model name" OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map" - class ModelNameFormatter(string.Formatter): - model_keyword: t.LiteralString = "__model_name__" + model_keyword: t.LiteralString = "__model_name__" - def __init__(self, model_name: str): - """The formatter that extends model_name to be formatted the 'service.py'.""" - super().__init__() - self.model_name = model_name + def __init__(self, model_name: str): + """The formatter that extends model_name to be formatted the 'service.py'.""" + super().__init__() + self.model_name = model_name - def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any: - return super().vformat(format_string, (), {self.model_keyword: self.model_name}) - - def can_format(self, value: str) -> bool: - try: - self.parse(value) - return True - except ValueError: - return False + def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any: + return super().vformat(format_string, (), {self.model_keyword: self.model_name}) + def can_format(self, value: str) -> bool: + try: + self.parse(value) + return True + except ValueError: + return False class ModelIdFormatter(ModelNameFormatter): - model_keyword: t.LiteralString = "__model_id__" - + model_keyword: t.LiteralString = "__model_id__" class ModelAdapterMapFormatter(ModelNameFormatter): - model_keyword: t.LiteralString = "__model_adapter_map__" - + model_keyword: t.LiteralString = "__model_adapter_map__" _service_file = Path(__file__).parent.parent / "_service.py" - def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None: - from . import DEBUG + from . import DEBUG - model_name = llm.config["model_name"] + model_name = llm.config["model_name"] - logger.debug("Generating service for %s", model_name) + logger.debug("Generating service for %s", model_name) - with open(_service_file.__fspath__(), "r") as f: - src_contents = f.readlines() + with open(_service_file.__fspath__(), "r") as f: + src_contents = f.readlines() - # modify with model name - for it in src_contents: - if OPENLLM_MODEL_NAME in it: - src_contents[src_contents.index(it)] = ( - ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + "\n" - ) - elif OPENLLM_MODEL_ADAPTER_MAP in it: - src_contents[src_contents.index(it)] = ( - ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[ - : -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3) - ] - + "\n" - ) + # modify with model name + for it in src_contents: + if OPENLLM_MODEL_NAME in it: + src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n") + elif OPENLLM_MODEL_ADAPTER_MAP in it: + src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n") - script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents) + script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents) - if DEBUG: - logger.info("Generated script:\n%s", script) - - llm_fs.writetext(llm.config["service_name"], script) + if DEBUG: + logger.info("Generated script:\n%s", script) + llm_fs.writetext(llm.config["service_name"], script) # NOTE: The following ins extracted from attrs internal APIs # sentinel object for unequivocal object() getattr _sentinel = object() - def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool: - """Check whether *cls* defines *attrib_name* (and doesn't just inherit it).""" - attr = getattr(cls, attrib_name, _sentinel) - if attr is _sentinel: - return False + """Check whether *cls* defines *attrib_name* (and doesn't just inherit it).""" + attr = getattr(cls, attrib_name, _sentinel) + if attr is _sentinel: + return False - for base_cls in cls.__mro__[1:]: - a = getattr(base_cls, attrib_name, None) - if attr is a: - return False - - return True + for base_cls in cls.__mro__[1:]: + a = getattr(base_cls, attrib_name, None) + if attr is a: + return False + return True def get_annotations(cls: type[t.Any]) -> DictStrAny: - """Get annotations for *cls*.""" - if has_own_attribute(cls, "__annotations__"): - return cls.__annotations__ + """Get annotations for *cls*.""" + if has_own_attribute(cls, "__annotations__"): + return cls.__annotations__ - return DictStrAny() - - -_classvar_prefixes = ( - "typing.ClassVar", - "t.ClassVar", - "ClassVar", - "typing_extensions.ClassVar", -) + return DictStrAny() +_classvar_prefixes = ("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar",) def is_class_var(annot: str | t.Any) -> bool: - """Check whether *annot* is a typing.ClassVar. + """Check whether *annot* is a typing.ClassVar. - The string comparison hack is used to avoid evaluating all string - annotations which would put attrs-based classes at a performance - disadvantage compared to plain old classes. - """ - annot = str(annot) + The string comparison hack is used to avoid evaluating all string + annotations which would put attrs-based classes at a performance + disadvantage compared to plain old classes. + """ + annot = str(annot) - # Annotation can be quoted. - if annot.startswith(("'", '"')) and annot.endswith(("'", '"')): - annot = annot[1:-1] - - return annot.startswith(_classvar_prefixes) + # Annotation can be quoted. + if annot.startswith(("'", '"')) and annot.endswith(("'", '"')): + annot = annot[1:-1] + return annot.startswith(_classvar_prefixes) def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T: - """Add __module__ and __qualname__ to a *method* if possible.""" - try: method_or_cls.__module__ = cls.__module__ - except AttributeError: pass - try: method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}" - except AttributeError: pass - try: method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}." - except AttributeError: pass - return method_or_cls - + """Add __module__ and __qualname__ to a *method* if possible.""" + try: + method_or_cls.__module__ = cls.__module__ + except AttributeError: + pass + try: + method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}" + except AttributeError: + pass + try: + method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}." + except AttributeError: + pass + return method_or_cls # Exec the script with the given global (globs) and local (locs) variables. -def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: eval(compile(script, filename, "exec"), globs, locs) # noqa: S307 +def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: + eval(compile(script, filename, "exec"), globs, locs) # noqa: S307 # ported from attrs def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable: - """Create the method with the script given and return the method object.""" - locs: DictStrAny = {} - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - count = 1 - base_filename = filename - while True: - linecache_tuple = (len(script), None, script.splitlines(True), filename) - old_val = linecache.cache.setdefault(filename, linecache_tuple) - if old_val == linecache_tuple: break - else: - filename = f"{base_filename[:-1]}-{count}>" - count += 1 - _compile_and_eval(script, globs, locs, filename) - return locs[name] + """Create the method with the script given and return the method object.""" + locs: DictStrAny = {} + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + count = 1 + base_filename = filename + while True: + linecache_tuple = (len(script), None, script.splitlines(True), filename) + old_val = linecache.cache.setdefault(filename, linecache_tuple) + if old_val == linecache_tuple: break + else: + filename = f"{base_filename[:-1]}-{count}>" + count += 1 + _compile_and_eval(script, globs, locs, filename) + return locs[name] def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]: - """Create a tuple subclass to hold class attributes. + """Create a tuple subclass to hold class attributes. - The subclass is a bare tuple with properties for names. + The subclass is a bare tuple with properties for names. - class MyClassAttributes(tuple): - __slots__ = () - x = property(itemgetter(0)) - """ - from . import SHOW_CODEGEN + class MyClassAttributes(tuple): + __slots__ = () + x = property(itemgetter(0)) + """ + from . import SHOW_CODEGEN - attr_class_name = f"{cls_name}Attributes" - attr_class_template = [ - f"class {attr_class_name}(tuple):", - " __slots__ = ()", - ] - if attr_names: - for i, attr_name in enumerate(attr_names): attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))") - else: attr_class_template.append(" pass") - globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} - if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template)) - _compile_and_eval("\n".join(attr_class_template), globs) - return globs[attr_class_name] + attr_class_name = f"{cls_name}Attributes" + attr_class_template = [f"class {attr_class_name}(tuple):", " __slots__ = ()",] + if attr_names: + for i, attr_name in enumerate(attr_names): + attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))") + else: + attr_class_template.append(" pass") + globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} + if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template)) + _compile_and_eval("\n".join(attr_class_template), globs) + return globs[attr_class_name] +def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: + return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>" -def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>" +def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None,) -> AnyCallable: + from . import SHOW_CODEGEN + script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass") + meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs) + if annotations: meth.__annotations__ = annotations + if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script) -def generate_function( - typ: type[t.Any], - func_name: str, - lines: list[str] | None, - args: tuple[str, ...] | None, - globs: dict[str, t.Any], - annotations: dict[str, t.Any] | None = None, -) -> AnyCallable: - from . import SHOW_CODEGEN + return meth - script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass") - meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs) - if annotations: meth.__annotations__ = annotations - if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script) +def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: t.LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable: + from . import dantic + from . import field_env_key - return meth + def identity(_: str, x_value: t.Any) -> t.Any: + return x_value + default_callback = identity if default_callback is None else default_callback -def make_env_transformer( - cls: type[openllm.LLMConfig], - model_name: str, - suffix: t.LiteralString | None = None, - default_callback: t.Callable[[str, t.Any], t.Any] | None = None, - globs: DictStrAny | None = None, -) -> AnyCallable: - from . import dantic - from . import field_env_key + globs = {} if globs is None else globs + globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,}) - def identity(_: str, x_value: t.Any) -> t.Any: - return x_value + lines: ListStr = [ + "__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },", + " )", " for f in fields", "]", + ] + fields_ann = "list[attr.Attribute[t.Any]]" - default_callback = identity if default_callback is None else default_callback - - globs = {} if globs is None else globs - globs.update( - { - "__populate_env": dantic.env_converter, - "__default_callback": default_callback, - "__field_env": field_env_key, - "__suffix": suffix or "", - "__model_name": model_name, - } - ) - - lines: ListStr = [ - "__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", - "return [", - " f.evolve(", - " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", - " metadata={", - " 'env': f.metadata.get('env', __env(f.name)),", - " 'description': f.metadata.get('description', '(not provided)'),", - " },", - " )", - " for f in fields", - "]", - ] - fields_ann = "list[attr.Attribute[t.Any]]" - - return generate_function( - cls, - "__auto_env", - lines, - args=("_", "fields"), - globs=globs, - annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann}, - ) + return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},) def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T: - """Enhance sdk with nice repr that plays well with your brain.""" - from .representation import ReprMixin + """Enhance sdk with nice repr that plays well with your brain.""" + from .representation import ReprMixin - if name is None: name = func.__name__.strip("_") - _signatures = inspect.signature(func).parameters - def _repr(self: ReprMixin) -> str: return f"" - def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__) - if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}" - else: doc = func.__doc__ - return t.cast(_T, functools.update_wrapper( - types.new_class( - name, - (PartialAny, ReprMixin), - exec_body=lambda ns: ns.update( - { - "__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), - "__repr_args__": _repr_args, - "__repr__": _repr, - "__doc__": inspect.cleandoc(doc), - "__module__": "openllm", - } - ), - )(func, **attrs), - func, - )) + if name is None: name = func.__name__.strip("_") + _signatures = inspect.signature(func).parameters + + def _repr(self: ReprMixin) -> str: + return f"" + + def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: + return ((k, _signatures[k].annotation) for k in self.__repr_keys__) + + if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}" + else: doc = func.__doc__ + return t.cast(_T, functools.update_wrapper(types.new_class(name, (PartialAny, ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,)) diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index 3fa3c4b7..e685636d 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -31,478 +31,432 @@ from click import shell_completion as sc from click import types as click_types if t.TYPE_CHECKING: - from attr import _ValidatorType + from attr import _ValidatorType - from .._types import ListAny + from .._types import ListAny _T = t.TypeVar("_T") AnyCallable = t.Callable[..., t.Any] FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command]) +def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: type[t.Any] | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]: + # TODO: support parsing nested attrs class and Union + envvar = field.metadata["env"] + dasherized = inflection.dasherize(name) + underscored = inflection.underscore(name) -def attrs_to_options( - name: str, - field: attr.Attribute[t.Any], - model_name: str, - typ: type[t.Any] | None = None, - suffix_generation: bool = False, - suffix_sampling: bool = False, -) -> t.Callable[[FC], FC]: - # TODO: support parsing nested attrs class and Union - envvar = field.metadata["env"] - dasherized = inflection.dasherize(name) - underscored = inflection.underscore(name) + if typ in (None, attr.NOTHING): + typ = field.type + if typ is None: raise RuntimeError(f"Failed to parse type for {name}") - if typ in (None, attr.NOTHING): - typ = field.type - if typ is None: raise RuntimeError(f"Failed to parse type for {name}") - - full_option_name = f"--{dasherized}" - if field.type is bool: full_option_name += f"/--no-{dasherized}" - if suffix_generation: identifier = f"{model_name}_generation_{underscored}" - elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}" - else: identifier = f"{model_name}_{underscored}" - - return cog.optgroup.option( - identifier, - full_option_name, - type=parse_type(typ), - required=field.default is attr.NOTHING, - default=field.default if field.default not in (attr.NOTHING, None) else None, - show_default=True, - multiple=allows_multiple(typ) if typ else False, - help=field.metadata.get("description", "(No description provided)"), - show_envvar=True, - envvar=envvar, - ) + full_option_name = f"--{dasherized}" + if field.type is bool: full_option_name += f"/--no-{dasherized}" + if suffix_generation: identifier = f"{model_name}_generation_{underscored}" + elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}" + else: identifier = f"{model_name}_{underscored}" + return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,) def env_converter(value: t.Any, env: str | None = None) -> t.Any: - if env is not None: - value = os.environ.get(env, value) - if value is not None and isinstance(value, str): - try: - return orjson.loads(value.lower()) - except orjson.JSONDecodeError as err: - raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None - return value + if env is not None: + value = os.environ.get(env, value) + if value is not None and isinstance(value, str): + try: + return orjson.loads(value.lower()) + except orjson.JSONDecodeError as err: + raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None + return value +def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any: + """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field. -def Field( - default: t.Any = None, - *, - ge: int | float | None = None, - le: int | float | None = None, - validator: _ValidatorType[_T] | None = None, - description: str | None = None, - env: str | None = None, - auto_default: bool = False, - use_default_converter: bool = True, - **attrs: t.Any, -) -> t.Any: - """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field. + By default, if both validator and ge are provided, then then ge will be + piped into first, then all of the other validator will be run afterwards. - By default, if both validator and ge are provided, then then ge will be - piped into first, then all of the other validator will be run afterwards. + Args: + default: The default value for ``dantic.Field``. Defaults to ``None``. + ge: Greater than or equal to. Defaults to None. + le: Less than or equal to. Defaults to None. + validator: Optional attrs-compatible validators type. Default to None + description: the documentation for the field. Defaults to None. + env: the environment variable to read from. Defaults to None. + auto_default: a bool indicating whether to use the default value as the environment. + Defaults to False. If set to True, the behaviour of this Field will also depends + on kw_only. If kw_only=True, the this field will become 'Required' and the default + value is omitted. If kw_only=False, then the default value will be used as before. + use_default_converter: a bool indicating whether to use the default converter. Defaults + to True. If set to False, then the default converter will not be used. + The default converter converts a given value from the environment variable + for this given Field. + **attrs: The rest of the arguments are passed to attr.field + """ + metadata = attrs.pop("metadata", {}) + if description is None: + description = "(No description provided)" + metadata["description"] = description + if env is not None: + metadata["env"] = env + piped: list[_ValidatorType[t.Any]] = [] - Args: - default: The default value for ``dantic.Field``. Defaults to ``None``. - ge: Greater than or equal to. Defaults to None. - le: Less than or equal to. Defaults to None. - validator: Optional attrs-compatible validators type. Default to None - description: the documentation for the field. Defaults to None. - env: the environment variable to read from. Defaults to None. - auto_default: a bool indicating whether to use the default value as the environment. - Defaults to False. If set to True, the behaviour of this Field will also depends - on kw_only. If kw_only=True, the this field will become 'Required' and the default - value is omitted. If kw_only=False, then the default value will be used as before. - use_default_converter: a bool indicating whether to use the default converter. Defaults - to True. If set to False, then the default converter will not be used. - The default converter converts a given value from the environment variable - for this given Field. - **attrs: The rest of the arguments are passed to attr.field - """ - metadata = attrs.pop("metadata", {}) - if description is None: - description = "(No description provided)" - metadata["description"] = description - if env is not None: - metadata["env"] = env - piped: list[_ValidatorType[t.Any]] = [] + converter = attrs.pop("converter", None) + if use_default_converter: + converter = functools.partial(env_converter, env=env) - converter = attrs.pop("converter", None) - if use_default_converter: - converter = functools.partial(env_converter, env=env) + if ge is not None: + piped.append(attr.validators.ge(ge)) + if le is not None: + piped.append(attr.validators.le(le)) + if validator is not None: + piped.append(validator) - if ge is not None: - piped.append(attr.validators.ge(ge)) - if le is not None: - piped.append(attr.validators.le(le)) - if validator is not None: - piped.append(validator) + if len(piped) == 0: + _validator = None + elif len(piped) == 1: + _validator = piped[0] + else: + _validator = attr.validators.and_(*piped) - if len(piped) == 0: - _validator = None - elif len(piped) == 1: - _validator = piped[0] - else: - _validator = attr.validators.and_(*piped) + factory = attrs.pop("factory", None) + if factory is not None and default is not None: + raise RuntimeError("'factory' and 'default' are mutually exclusive.") + # NOTE: the behaviour of this is we will respect factory over the default + if factory is not None: + attrs["factory"] = factory + else: + attrs["default"] = default - factory = attrs.pop("factory", None) - if factory is not None and default is not None: - raise RuntimeError("'factory' and 'default' are mutually exclusive.") - # NOTE: the behaviour of this is we will respect factory over the default - if factory is not None: - attrs["factory"] = factory - else: - attrs["default"] = default - - kw_only = attrs.pop("kw_only", False) - if auto_default and kw_only: - attrs.pop("default") - - return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs) + kw_only = attrs.pop("kw_only", False) + if auto_default and kw_only: + attrs.pop("default") + return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs) def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]: - """Transforms the pydantic field's type into a click-compatible type. + """Transforms the pydantic field's type into a click-compatible type. - Args: - field_type: pydantic field type + Args: + field_type: pydantic field type - Returns: - ParamType: click type equivalent - """ - from . import lenient_issubclass - - if t.get_origin(field_type) is t.Union: - raise NotImplementedError("Unions are not supported") - # enumeration strings or other Enum derivatives - if lenient_issubclass(field_type, Enum): - return EnumChoice(enum=field_type, case_sensitive=True) - # literals are enum-like with way less functionality - if is_literal(field_type): - return LiteralChoice(value=field_type, case_sensitive=True) - # modules, classes, functions - if is_typing(field_type): - return ModuleType() - # entire dictionaries: - # using a Dict, convert in advance - if is_mapping(field_type): - return JsonType() - # list, List[p], Tuple[p], Set[p] and so on - if is_container(field_type): - return parse_container_args(field_type) - # bytes are not natively supported by click - if lenient_issubclass(field_type, bytes): - return BytesType() - # return the current type: it should be a primitive - return field_type + Returns: + ParamType: click type equivalent + """ + from . import lenient_issubclass + if t.get_origin(field_type) is t.Union: + raise NotImplementedError("Unions are not supported") + # enumeration strings or other Enum derivatives + if lenient_issubclass(field_type, Enum): + return EnumChoice(enum=field_type, case_sensitive=True) + # literals are enum-like with way less functionality + if is_literal(field_type): + return LiteralChoice(value=field_type, case_sensitive=True) + # modules, classes, functions + if is_typing(field_type): + return ModuleType() + # entire dictionaries: + # using a Dict, convert in advance + if is_mapping(field_type): + return JsonType() + # list, List[p], Tuple[p], Set[p] and so on + if is_container(field_type): + return parse_container_args(field_type) + # bytes are not natively supported by click + if lenient_issubclass(field_type, bytes): + return BytesType() + # return the current type: it should be a primitive + return field_type def is_typing(field_type: type) -> bool: - """Checks whether the current type is a module-like type. + """Checks whether the current type is a module-like type. - Args: - field_type: pydantic field type + Args: + field_type: pydantic field type - Returns: - bool: true if the type is itself a type - """ - raw = t.get_origin(field_type) - if raw is None: - return False - if raw is type or raw is t.Type: - return True + Returns: + bool: true if the type is itself a type + """ + raw = t.get_origin(field_type) + if raw is None: return False - + if raw is type or raw is t.Type: + return True + return False def is_literal(field_type: type) -> bool: - """Checks whether the given field type is a Literal type or not. + """Checks whether the given field type is a Literal type or not. - Literals are weird: isinstance and subclass do not work, so you compare - the origin with the Literal declaration itself. + Literals are weird: isinstance and subclass do not work, so you compare + the origin with the Literal declaration itself. - Args: - field_type: current pydantic type - - Returns: - bool: true if Literal type, false otherwise - """ - origin = t.get_origin(field_type) - return origin is not None and origin is t.Literal + Args: + field_type: current pydantic type + Returns: + bool: true if Literal type, false otherwise + """ + origin = t.get_origin(field_type) + return origin is not None and origin is t.Literal class ModuleType(ParamType): - name = "module" + name = "module" - def _import_object(self, value: str) -> t.Any: - module_name, class_name = value.rsplit(".", maxsplit=1) - if not all(s.isidentifier() for s in module_name.split(".")): - raise ValueError(f"'{value}' is not a valid module name") - if not class_name.isidentifier(): - raise ValueError(f"Variable '{class_name}' is not a valid identifier") + def _import_object(self, value: str) -> t.Any: + module_name, class_name = value.rsplit(".", maxsplit=1) + if not all(s.isidentifier() for s in module_name.split(".")): + raise ValueError(f"'{value}' is not a valid module name") + if not class_name.isidentifier(): + raise ValueError(f"Variable '{class_name}' is not a valid identifier") - module = importlib.import_module(module_name) - if class_name: - try: - return getattr(module, class_name) - except AttributeError: - raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None - - def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: - try: - if isinstance(value, str): - return self._import_object(value) - return value - except Exception as exc: - self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx) + module = importlib.import_module(module_name) + if class_name: + try: + return getattr(module, class_name) + except AttributeError: + raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None + def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: + try: + if isinstance(value, str): + return self._import_object(value) + return value + except Exception as exc: + self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx) class EnumChoice(click.Choice): - name = "enum" + name = "enum" - def __init__(self, enum: Enum, case_sensitive: bool = False): - """Enum type support for click that extends ``click.Choice``. + def __init__(self, enum: Enum, case_sensitive: bool = False): + """Enum type support for click that extends ``click.Choice``. - Args: - enum: Given enum - case_sensitive: Whether this choice should be case case_sensitive. - """ - self.mapping = enum - self.internal_type = type(enum) - choices: ListAny = [e.name for e in enum.__class__] - super().__init__(choices, case_sensitive) - - def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum: - if isinstance(value, self.internal_type): - return value - result = super().convert(value, param, ctx) - if isinstance(result, str): - result = self.internal_type[result] - return result + Args: + enum: Given enum + case_sensitive: Whether this choice should be case case_sensitive. + """ + self.mapping = enum + self.internal_type = type(enum) + choices: ListAny = [e.name for e in enum.__class__] + super().__init__(choices, case_sensitive) + def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum: + if isinstance(value, self.internal_type): + return value + result = super().convert(value, param, ctx) + if isinstance(result, str): + result = self.internal_type[result] + return result class LiteralChoice(EnumChoice): - name = "literal" - - def __init__(self, value: t.Any, case_sensitive: bool = False): - """Literal support for click.""" - # expect every literal value to belong to the same primitive type - values = list(value.__args__) - item_type = type(values[0]) - if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.") - _mapping = {str(v): v for v in values} - super(EnumChoice, self).__init__(list(_mapping), case_sensitive) - self.internal_type = item_type + name = "literal" + def __init__(self, value: t.Any, case_sensitive: bool = False): + """Literal support for click.""" + # expect every literal value to belong to the same primitive type + values = list(value.__args__) + item_type = type(values[0]) + if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.") + _mapping = {str(v): v for v in values} + super(EnumChoice, self).__init__(list(_mapping), case_sensitive) + self.internal_type = item_type def allows_multiple(field_type: type[t.Any]) -> bool: - """Checks whether the current type allows for multiple arguments to be provided as input or not. + """Checks whether the current type allows for multiple arguments to be provided as input or not. - For containers, it exploits click's support for lists and such to use the same option multiple times - to create a complex object: `python run.py --subsets train --subsets test` - # becomes `subsets: ["train", "test"]`. + For containers, it exploits click's support for lists and such to use the same option multiple times + to create a complex object: `python run.py --subsets train --subsets test` + # becomes `subsets: ["train", "test"]`. - Args: - field_type: pydantic type. + Args: + field_type: pydantic type. - Returns: - bool: true if it's a composite field (lists, containers and so on), false otherwise - """ - # Early out for mappings, since it's better to deal with them using strings. - if is_mapping(field_type): - return False - # Activate multiple option for (simple) container types - if is_container(field_type): - args = parse_container_args(field_type) - # A non-composite type has a single argument, such as 'List[int]' - # A composite type has a tuple of arguments, like 'Tuple[str, int, int]'. - # For the moment, only non-composite types are allowed. - return not isinstance(args, tuple) + Returns: + bool: true if it's a composite field (lists, containers and so on), false otherwise + """ + # Early out for mappings, since it's better to deal with them using strings. + if is_mapping(field_type): return False - + # Activate multiple option for (simple) container types + if is_container(field_type): + args = parse_container_args(field_type) + # A non-composite type has a single argument, such as 'List[int]' + # A composite type has a tuple of arguments, like 'Tuple[str, int, int]'. + # For the moment, only non-composite types are allowed. + return not isinstance(args, tuple) + return False def is_mapping(field_type: type) -> bool: - """Checks whether this field represents a dictionary or JSON object. + """Checks whether this field represents a dictionary or JSON object. - Args: - field_type (type): pydantic type - - Returns: - bool: true when the field is a dict-like object, false otherwise. - """ - # Early out for standard containers. - from . import lenient_issubclass - if lenient_issubclass(field_type, t.Mapping): return True - # for everything else or when the typing is more complex, check its origin - origin = t.get_origin(field_type) - if origin is None: return False - return lenient_issubclass(origin, t.Mapping) + Args: + field_type (type): pydantic type + Returns: + bool: true when the field is a dict-like object, false otherwise. + """ + # Early out for standard containers. + from . import lenient_issubclass + if lenient_issubclass(field_type, t.Mapping): return True + # for everything else or when the typing is more complex, check its origin + origin = t.get_origin(field_type) + if origin is None: return False + return lenient_issubclass(origin, t.Mapping) def is_container(field_type: type) -> bool: - """Checks whether the current type is a container type ('contains' other types), like lists and tuples. + """Checks whether the current type is a container type ('contains' other types), like lists and tuples. - Args: - field_type: pydantic field type - - Returns: - bool: true if a container, false otherwise - """ - # do not consider strings or byte arrays as containers - if field_type in (str, bytes): return False - # Early out for standard containers: list, tuple, range - from . import lenient_issubclass - if lenient_issubclass(field_type, t.Container): return True - origin = t.get_origin(field_type) - # Early out for non-typing objects - if origin is None: return False - return lenient_issubclass(origin, t.Container) + Args: + field_type: pydantic field type + Returns: + bool: true if a container, false otherwise + """ + # do not consider strings or byte arrays as containers + if field_type in (str, bytes): return False + # Early out for standard containers: list, tuple, range + from . import lenient_issubclass + if lenient_issubclass(field_type, t.Container): return True + origin = t.get_origin(field_type) + # Early out for non-typing objects + if origin is None: return False + return lenient_issubclass(origin, t.Container) def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType]: - """Parses the arguments inside a container type (lists, tuples and so on). + """Parses the arguments inside a container type (lists, tuples and so on). - Args: - field_type: pydantic field type - - Returns: - ParamType | tuple[ParamType]: single click-compatible type or a tuple - """ - if not is_container(field_type): - raise ValueError("Field type is not a container type.") - args = t.get_args(field_type) - # Early out for untyped containers: standard lists, tuples, List[Any] - # Use strings when the type is unknown, avoid click's type guessing - if len(args) == 0: - return click_types.convert_type(str) - # Early out for homogenous containers: Tuple[int], List[str] - if len(args) == 1: - return parse_single_arg(args[0]) - # Early out for homogenous tuples of indefinite length: Tuple[int, ...] - if len(args) == 2 and args[1] is Ellipsis: - return parse_single_arg(args[0]) - # Then deal with fixed-length containers: Tuple[str, int, int] - return tuple(parse_single_arg(arg) for arg in args) + Args: + field_type: pydantic field type + Returns: + ParamType | tuple[ParamType]: single click-compatible type or a tuple + """ + if not is_container(field_type): + raise ValueError("Field type is not a container type.") + args = t.get_args(field_type) + # Early out for untyped containers: standard lists, tuples, List[Any] + # Use strings when the type is unknown, avoid click's type guessing + if len(args) == 0: + return click_types.convert_type(str) + # Early out for homogenous containers: Tuple[int], List[str] + if len(args) == 1: + return parse_single_arg(args[0]) + # Early out for homogenous tuples of indefinite length: Tuple[int, ...] + if len(args) == 2 and args[1] is Ellipsis: + return parse_single_arg(args[0]) + # Then deal with fixed-length containers: Tuple[str, int, int] + return tuple(parse_single_arg(arg) for arg in args) def parse_single_arg(arg: type) -> ParamType: - """Returns the click-compatible type for container origin types. + """Returns the click-compatible type for container origin types. - In this case, returns string when it's not inferrable, a JSON for mappings - and the original type itself in every other case (ints, floats and so on). - Bytes is a special case, not natively handled by click. + In this case, returns string when it's not inferrable, a JSON for mappings + and the original type itself in every other case (ints, floats and so on). + Bytes is a special case, not natively handled by click. - Args: - arg (type): single argument - - Returns: - ParamType: click-compatible type - """ - from . import lenient_issubclass - # When we don't know the type, we choose 'str' - if arg is t.Any: return click_types.convert_type(str) - # For containers and nested models, we use JSON - if is_container(arg): return JsonType() - if lenient_issubclass(arg, bytes): return BytesType() - return click_types.convert_type(arg) + Args: + arg (type): single argument + Returns: + ParamType: click-compatible type + """ + from . import lenient_issubclass + # When we don't know the type, we choose 'str' + if arg is t.Any: return click_types.convert_type(str) + # For containers and nested models, we use JSON + if is_container(arg): return JsonType() + if lenient_issubclass(arg, bytes): return BytesType() + return click_types.convert_type(arg) class BytesType(ParamType): - name = "bytes" - - def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: - if isinstance(value, bytes): - return value - try: - return str.encode(value) - except Exception as exc: - self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx) + name = "bytes" + def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: + if isinstance(value, bytes): + return value + try: + return str.encode(value) + except Exception as exc: + self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx) CYGWIN = sys.platform.startswith("cygwin") WIN = sys.platform.startswith("win") if sys.platform.startswith("win") and WIN: - def _get_argv_encoding() -> str: - import locale + def _get_argv_encoding() -> str: + import locale - return locale.getpreferredencoding() + return locale.getpreferredencoding() else: - def _get_argv_encoding() -> str: - return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding() - + def _get_argv_encoding() -> str: + return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding() class CudaValueType(ParamType): - name = "cuda" - envvar_list_splitter = "," - is_composite = True - typ = click_types.convert_type(str) + name = "cuda" + envvar_list_splitter = "," + is_composite = True + typ = click_types.convert_type(str) - def split_envvar_value(self, rv: str) -> t.Sequence[str]: - var = tuple(i for i in rv.split(self.envvar_list_splitter)) - if "-1" in var: - return var[: var.index("-1")] - return var + def split_envvar_value(self, rv: str) -> t.Sequence[str]: + var = tuple(i for i in rv.split(self.envvar_list_splitter)) + if "-1" in var: + return var[:var.index("-1")] + return var - def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]: - """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value. + def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]: + """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value. - Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well. + Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well. - Args: - ctx: Invocation context for this command. - param: The parameter that is requesting completion. - incomplete: Value being completed. May be empty. - """ - from ..utils import available_devices + Args: + ctx: Invocation context for this command. + param: The parameter that is requesting completion. + incomplete: Value being completed. May be empty. + """ + from ..utils import available_devices - mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices() + mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices() - return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping] + return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping] - def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: - if isinstance(value, bytes): - enc = _get_argv_encoding() - try: - value = value.decode(enc) - except UnicodeError: - fs_enc = sys.getfilesystemencoding() - if fs_enc != enc: - try: - value = value.decode(fs_enc) - except UnicodeError: - value = value.decode("utf-8", "replace") - else: - value = value.decode("utf-8", "replace") + def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: + if isinstance(value, bytes): + enc = _get_argv_encoding() + try: + value = value.decode(enc) + except UnicodeError: + fs_enc = sys.getfilesystemencoding() + if fs_enc != enc: + try: + value = value.decode(fs_enc) + except UnicodeError: + value = value.decode("utf-8", "replace") + else: + value = value.decode("utf-8", "replace") - return tuple(self.typ(x, param, ctx) for x in value.split(",")) - - def __repr__(self) -> str: - """CUDA is a click.STRING extension.""" - return "STRING" + return tuple(self.typ(x, param, ctx) for x in value.split(",")) + def __repr__(self) -> str: + """CUDA is a click.STRING extension.""" + return "STRING" CUDA = CudaValueType() - class JsonType(ParamType): - name = "json" + name = "json" - def __init__(self, should_load: bool = True) -> None: - """Support JSON type for click.ParamType. + def __init__(self, should_load: bool = True) -> None: + """Support JSON type for click.ParamType. - Args: - should_load: Whether to load the JSON. Default to True. If False, the value won't be converted. - """ - super().__init__() - self.should_load = should_load + Args: + should_load: Whether to load the JSON. Default to True. If False, the value won't be converted. + """ + super().__init__() + self.should_load = should_load - def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: - from . import LazyType - if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value - try: return orjson.loads(value) - except orjson.JSONDecodeError as exc: self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx) + def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: + from . import LazyType + if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value + try: + return orjson.loads(value) + except orjson.JSONDecodeError as exc: + self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx) diff --git a/src/openllm/utils/dummy_flax_objects.py b/src/openllm/utils/dummy_flax_objects.py index c30717cc..36bb4595 100644 --- a/src/openllm/utils/dummy_flax_objects.py +++ b/src/openllm/utils/dummy_flax_objects.py @@ -19,27 +19,24 @@ from ..utils import DummyMetaclass from ..utils import require_backends if t.TYPE_CHECKING: - from ..models.auto.factory import _LazyAutoMapping + from ..models.auto.factory import _LazyAutoMapping class FlaxFlanT5(metaclass=DummyMetaclass): - _backends = ["flax"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["flax"]) + _backends = ["flax"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["flax"]) class FlaxOPT(metaclass=DummyMetaclass): - _backends = ["flax"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["flax"]) + _backends = ["flax"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["flax"]) class AutoFlaxLLM(metaclass=DummyMetaclass): - _backends = ["flax"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["flax"]) + _backends = ["flax"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["flax"]) MODEL_FLAX_MAPPING = t.cast("_LazyAutoMapping", None) diff --git a/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py b/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py index 6830e5fb..0f2b4031 100644 --- a/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py +++ b/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py @@ -19,14 +19,13 @@ from ..utils import DummyMetaclass from ..utils import require_backends class ChatGLM(metaclass=DummyMetaclass): - _backends = ["torch", "cpm_kernels"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch", "cpm_kernels"]) + _backends = ["torch", "cpm_kernels"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch", "cpm_kernels"]) class Baichuan(metaclass=DummyMetaclass): - _backends = ["torch", "cpm_kernels"] + _backends = ["torch", "cpm_kernels"] - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch", "cpm_kernels"]) + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch", "cpm_kernels"]) diff --git a/src/openllm/utils/dummy_pt_and_einops_objects.py b/src/openllm/utils/dummy_pt_and_einops_objects.py index e10dd6ff..e3b54dfe 100644 --- a/src/openllm/utils/dummy_pt_and_einops_objects.py +++ b/src/openllm/utils/dummy_pt_and_einops_objects.py @@ -19,7 +19,7 @@ from ..utils import DummyMetaclass from ..utils import require_backends class Falcon(metaclass=DummyMetaclass): - _backends = ["torch", "einops"] + _backends = ["torch", "einops"] - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch", "einops"]) + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch", "einops"]) diff --git a/src/openllm/utils/dummy_pt_and_triton_objects.py b/src/openllm/utils/dummy_pt_and_triton_objects.py index e0ff894a..451eb77b 100644 --- a/src/openllm/utils/dummy_pt_and_triton_objects.py +++ b/src/openllm/utils/dummy_pt_and_triton_objects.py @@ -19,7 +19,7 @@ from ..utils import DummyMetaclass from ..utils import require_backends class MPT(metaclass=DummyMetaclass): - _backends = ["torch", "triton"] + _backends = ["torch", "triton"] - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) diff --git a/src/openllm/utils/dummy_pt_objects.py b/src/openllm/utils/dummy_pt_objects.py index c03d2954..5b518fc8 100644 --- a/src/openllm/utils/dummy_pt_objects.py +++ b/src/openllm/utils/dummy_pt_objects.py @@ -18,62 +18,54 @@ from ..utils import DummyMetaclass from ..utils import require_backends if t.TYPE_CHECKING: - from ..models.auto.factory import _LazyAutoMapping + from ..models.auto.factory import _LazyAutoMapping class FlanT5(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class OPT(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class GPTNeoX(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class DollyV2(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class StarCoder(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class StableLM(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class Llama(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) class AutoLLM(metaclass=DummyMetaclass): - _backends = ["torch"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["torch"]) + _backends = ["torch"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["torch"]) MODEL_MAPPING = t.cast("_LazyAutoMapping", None) diff --git a/src/openllm/utils/dummy_tf_objects.py b/src/openllm/utils/dummy_tf_objects.py index ff7d2acd..ee83a12c 100644 --- a/src/openllm/utils/dummy_tf_objects.py +++ b/src/openllm/utils/dummy_tf_objects.py @@ -18,27 +18,24 @@ from ..utils import DummyMetaclass from ..utils import require_backends if t.TYPE_CHECKING: - from ..models.auto.factory import _LazyAutoMapping + from ..models.auto.factory import _LazyAutoMapping class TFFlanT5(metaclass=DummyMetaclass): - _backends = ["tf"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["tf"]) + _backends = ["tf"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["tf"]) class TFOPT(metaclass=DummyMetaclass): - _backends = ["tf"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["tf"]) + _backends = ["tf"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["tf"]) class AutoTFLLM(metaclass=DummyMetaclass): - _backends = ["tf"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["tf"]) + _backends = ["tf"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["tf"]) MODEL_TF_MAPPING = t.cast("_LazyAutoMapping", None) diff --git a/src/openllm/utils/dummy_vllm_objects.py b/src/openllm/utils/dummy_vllm_objects.py index 2e1b2832..bb819e33 100644 --- a/src/openllm/utils/dummy_vllm_objects.py +++ b/src/openllm/utils/dummy_vllm_objects.py @@ -18,26 +18,24 @@ from ..utils import DummyMetaclass from ..utils import require_backends if t.TYPE_CHECKING: - from ..models.auto.factory import _LazyAutoMapping + from ..models.auto.factory import _LazyAutoMapping class VLLMLlama(metaclass=DummyMetaclass): - _backends = ["vllm"] + _backends = ["vllm"] - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["vllm"]) + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["vllm"]) class VLLMOPT(metaclass=DummyMetaclass): - _backends = ["vllm"] - - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["vllm"]) + _backends = ["vllm"] + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["vllm"]) class AutoVLLM(metaclass=DummyMetaclass): - _backends = ["vllm"] + _backends = ["vllm"] - def __init__(self, *args: t.Any, **attrs: t.Any): - require_backends(self, ["vllm"]) + def __init__(self, *args: t.Any, **attrs: t.Any): + require_backends(self, ["vllm"]) - -MODEL_VLLM_MAPPING = t.cast("_LazyAutoMapping", None) +MODEL_VLLM_MAPPING = t.cast("_LazyAutoMapping", None) diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index ac9b587b..9bf4c18a 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons.""" from __future__ import annotations import functools @@ -36,37 +35,27 @@ from .representation import ReprMixin # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import overload + from typing import overload else: - from typing_extensions import overload + from typing_extensions import overload if t.TYPE_CHECKING: - BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]] - from .._types import LiteralRuntime - from .._types import P - from .._types import T + BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]] + from .._types import LiteralRuntime + from .._types import P + from .._types import T - class _AnnotatedLazyLoader(LazyLoader, t.Generic[T]): - DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[[T], t.LiteralString] - PROMPT_MAPPING: dict[T, t.LiteralString] | None + class _AnnotatedLazyLoader(LazyLoader, t.Generic[T]): + DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[[T], t.LiteralString] + PROMPT_MAPPING: dict[T, t.LiteralString] | None else: - _AnnotatedLazyLoader = LazyLoader - BackendOrderredDict = OrderedDict + _AnnotatedLazyLoader = LazyLoader + BackendOrderredDict = OrderedDict logger = logging.getLogger(__name__) -OPTIONAL_DEPENDENCIES = { - "opt", - "flan-t5", - "vllm", - "fine-tune", - "ggml", - "agents", - "openai", - "playground", - "gptq", -} +OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",} ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() @@ -74,14 +63,14 @@ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() - def _is_package_available(package: str) -> bool: - _package_available = importlib.util.find_spec(package) is not None - if _package_available: - try: importlib.metadata.version(package) - except importlib.metadata.PackageNotFoundError: _package_available = False - return _package_available - + _package_available = importlib.util.find_spec(package) is not None + if _package_available: + try: + importlib.metadata.version(package) + except importlib.metadata.PackageNotFoundError: + _package_available = False + return _package_available _torch_available = importlib.util.find_spec("torch") is not None _tf_available = importlib.util.find_spec("tensorflow") is not None @@ -98,98 +87,115 @@ _jupytext_available = _is_package_available("jupytext") _notebook_available = _is_package_available("notebook") _autogptq_available = _is_package_available("auto_gptq") -def is_transformers_supports_kbit() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 30) -def is_transformers_supports_agent() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 29) -def is_jupyter_available() -> bool: return _jupyter_available -def is_jupytext_available() -> bool: return _jupytext_available -def is_notebook_available() -> bool: return _notebook_available -def is_triton_available() -> bool: return _triton_available -def is_datasets_available() -> bool: return _datasets_available -def is_peft_available() -> bool: return _peft_available -def is_einops_available() -> bool: return _einops_available -def is_cpm_kernels_available() -> bool: return _cpm_kernel_available -def is_bitsandbytes_available() -> bool: return _bitsandbytes_available -def is_autogptq_available() -> bool: return _autogptq_available -def is_vllm_available() -> bool: return _vllm_available +def is_transformers_supports_kbit() -> bool: + return pkg.pkg_version_info("transformers")[:2] >= (4, 30) + +def is_transformers_supports_agent() -> bool: + return pkg.pkg_version_info("transformers")[:2] >= (4, 29) + +def is_jupyter_available() -> bool: + return _jupyter_available + +def is_jupytext_available() -> bool: + return _jupytext_available + +def is_notebook_available() -> bool: + return _notebook_available + +def is_triton_available() -> bool: + return _triton_available + +def is_datasets_available() -> bool: + return _datasets_available + +def is_peft_available() -> bool: + return _peft_available + +def is_einops_available() -> bool: + return _einops_available + +def is_cpm_kernels_available() -> bool: + return _cpm_kernel_available + +def is_bitsandbytes_available() -> bool: + return _bitsandbytes_available + +def is_autogptq_available() -> bool: + return _autogptq_available + +def is_vllm_available() -> bool: + return _vllm_available def is_torch_available() -> bool: - global _torch_available - if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - if _torch_available: - try: importlib.metadata.version("torch") - except importlib.metadata.PackageNotFoundError: _torch_available = False - else: - logger.info("Disabling PyTorch because USE_TF is set") + global _torch_available + if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + if _torch_available: + try: + importlib.metadata.version("torch") + except importlib.metadata.PackageNotFoundError: _torch_available = False - return _torch_available + else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + return _torch_available def is_tf_available() -> bool: - global _tf_available - if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True - else: + global _tf_available + if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True + else: + _tf_version = None + if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + if _tf_available: + candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64",) _tf_version = None - if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - if _tf_available: - candidates = ( - "tensorflow", - "tensorflow-cpu", - "tensorflow-gpu", - "tf-nightly", - "tf-nightly-cpu", - "tf-nightly-gpu", - "intel-tensorflow", - "intel-tensorflow-avx512", - "tensorflow-rocm", - "tensorflow-macos", - "tensorflow-aarch64", - ) - _tf_version = None - # For the metadata, we have to look for both tensorflow and tensorflow-cpu - for _pkg in candidates: - try: - _tf_version = importlib.metadata.version(_pkg) - break - except importlib.metadata.PackageNotFoundError: pass - _tf_available = _tf_version is not None - if _tf_available: - if _tf_version and version.parse(_tf_version) < version.parse("2"): - logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version) - _tf_available = False - else: - logger.info("Disabling Tensorflow because USE_TORCH is set") - _tf_available = False - return _tf_available + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for _pkg in candidates: + try: + _tf_version = importlib.metadata.version(_pkg) + break + except importlib.metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if _tf_version and version.parse(_tf_version) < version.parse("2"): + logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version) + _tf_available = False + else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + return _tf_available def is_flax_available() -> bool: - global _flax_available - if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: - if _flax_available: - try: - importlib.metadata.version("jax") - importlib.metadata.version("flax") - except importlib.metadata.PackageNotFoundError: _flax_available = False - else: _flax_available = False - return _flax_available - + global _flax_available + if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + if _flax_available: + try: + importlib.metadata.version("jax") + importlib.metadata.version("flax") + except importlib.metadata.PackageNotFoundError: + _flax_available = False + else: + _flax_available = False + return _flax_available def requires_dependencies(package: str | list[str], *, extra: str | list[str] | None = None) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]: - import openllm.utils + import openllm.utils - if isinstance(package, str): package = [package] - if isinstance(extra, str): extra = [extra] + if isinstance(package, str): package = [package] + if isinstance(extra, str): extra = [extra] - def decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: - for p in package: - cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None) - if not ((cached_check is not None and cached_check()) or _is_package_available(p)): raise ImportError( f"{func.__name__} requires '{p}' to be available locally (Currently missing). Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'") - return func(*args, **kwargs) + def decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: + for p in package: + cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None) + if not ((cached_check is not None and cached_check()) or _is_package_available(p)): + raise ImportError(f"{func.__name__} requires '{p}' to be available locally (Currently missing). Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'") + return func(*args, **kwargs) - return wrapper - - return decorator + return wrapper + return decorator VLLM_IMPORT_ERROR_WITH_PYTORCH = """\ {0} requires the vLLM library but it was not found in your environment. @@ -250,7 +256,6 @@ Checkout the instructions on the installation page: https://www.tensorflow.org/i ones that match your environment. Please note that you may need to restart your runtime after installation. """ - FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the installation page: https://github.com/google/flax and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. @@ -301,142 +306,117 @@ You can install it with pip: `pip install auto-gptq`. Please note that you may n your runtime after installation. """ -BACKENDS_MAPPING = BackendOrderredDict( - [ - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), - ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), - ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)), - ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), - ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), - ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), - ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), - ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), - ] -) - +BACKENDS_MAPPING = BackendOrderredDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)), + ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)),]) class DummyMetaclass(ABCMeta): - """Metaclass for dummy object. + """Metaclass for dummy object. - It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class. - """ + It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class. + """ - _backends: t.List[str] - - def __getattribute__(cls, key: str) -> t.Any: - if key.startswith("_"): return super().__getattribute__(key) - require_backends(cls, cls._backends) + _backends: t.List[str] + def __getattribute__(cls, key: str) -> t.Any: + if key.startswith("_"): return super().__getattribute__(key) + require_backends(cls, cls._backends) def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None: - if not isinstance(backends, (list, tuple)): backends = list(backends) - name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__ - # Raise an error for users who might not realize that classes without "TF" are torch-only - if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name)) - # Raise the inverse error for PyTorch users trying to load TF classes - if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) - # Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax - if "vllm" in backends: - if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name)) - if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name)) - if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name)) - - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] - if failed: raise ImportError("".join(failed)) + if not isinstance(backends, (list, tuple)): backends = list(backends) + name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__ + # Raise an error for users who might not realize that classes without "TF" are torch-only + if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name)) + # Raise the inverse error for PyTorch users trying to load TF classes + if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) + # Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax + if "vllm" in backends: + if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name)) + if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name)) + if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name)) + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: raise ImportError("".join(failed)) class EnvVarMixin(ReprMixin): - model_name: str + model_name: str - @property - def __repr_keys__(self) -> set[str]: - return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"} + @property + def __repr_keys__(self) -> set[str]: + return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"} - if t.TYPE_CHECKING: - config: str - model_id: str - quantize: str - framework: str - bettertransformer: str - runtime: t.Literal["ggml", "transformers"] + if t.TYPE_CHECKING: + config: str + model_id: str + quantize: str + framework: str + bettertransformer: str + runtime: t.Literal["ggml", "transformers"] - framework_value: LiteralRuntime - quantize_value: t.Literal["int8", "int4", "gptq"] | None - bettertransformer_value: bool | None - model_id_value: str | None - runtime_value: t.Literal["ggml", "transformers"] + framework_value: LiteralRuntime + quantize_value: t.Literal["int8", "int4", "gptq"] | None + bettertransformer_value: bool | None + model_id_value: str | None + runtime_value: t.Literal["ggml", "transformers"] - # fmt: off - @overload - def __getitem__(self, item: t.Literal["config"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["model_id"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["quantize"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["framework"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["runtime"]) -> str: ... - @overload - def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ... - @overload - def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ... - @overload - def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ... - @overload - def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ... - @overload - def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ... - # fmt: on + # fmt: off + @overload + def __getitem__(self, item: t.Literal["config"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["model_id"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["quantize"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["framework"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["runtime"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ... + @overload + def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ... + @overload + def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ... + @overload + def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ... + # fmt: on - def __getitem__(self, item: str | t.Any) -> t.Any: - if hasattr(self, item): return getattr(self, item) - raise KeyError(f"Key {item} not found in {self}") + def __getitem__(self, item: str | t.Any) -> t.Any: + if hasattr(self, item): return getattr(self, item) + raise KeyError(f"Key {item} not found in {self}") - def __new__( - cls, - model_name: str, - implementation: LiteralRuntime = "pt", - model_id: str | None = None, - bettertransformer: bool | None = None, - quantize: t.LiteralString | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers", - ) -> t.Self: - from . import codegen - from .._configuration import field_env_key - model_name = inflection.underscore(model_name) + def __new__(cls, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",) -> t.Self: + from . import codegen + from .._configuration import field_env_key + model_name = inflection.underscore(model_name) - res = super().__new__(cls) - res.model_name = model_name + res = super().__new__(cls) + res.model_name = model_name - # gen properties env key - for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(res, att, field_env_key(model_name, att.upper())) + # gen properties env key + for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: + setattr(res, att, field_env_key(model_name, att.upper())) - # gen properties env value - attributes_with_values = { - "framework": (str, implementation), - "quantize": (str, quantize), - "bettertransformer": (bool, bettertransformer), - "model_id": (str, model_id), - "runtime": (str, runtime), - } - globs: dict[str, t.Any] = {"__bool_vars_value": ENV_VARS_TRUE_VALUES, "__env_get": os.getenv, "self": res} + # gen properties env value + attributes_with_values = {"framework": (str, implementation), "quantize": (str, quantize), "bettertransformer": (bool, bettertransformer), "model_id": (str, model_id), "runtime": (str, runtime),} + globs: dict[str, t.Any] = {"__bool_vars_value": ENV_VARS_TRUE_VALUES, "__env_get": os.getenv, "self": res} - for attribute, (default_type, default_value) in attributes_with_values.items(): - lines: list[str] = [] - if default_type is bool: lines.append(f"return str(__env_get(self['{attribute}'], str(__env_default)).upper() in __bool_vars_value)") - else: lines.append(f"return __env_get(self['{attribute}'], __env_default)") + for attribute, (default_type, default_value) in attributes_with_values.items(): + lines: list[str] = [] + if default_type is bool: lines.append(f"return str(__env_get(self['{attribute}'], str(__env_default)).upper() in __bool_vars_value)") + else: lines.append(f"return __env_get(self['{attribute}'], __env_default)") - setattr(res, f"{attribute}_value", codegen.generate_function(cls, "_env_get_" + attribute, lines, ("__env_default",), globs)(default_value)) + setattr(res, f"{attribute}_value", codegen.generate_function(cls, "_env_get_" + attribute, lines, ("__env_default",), globs)(default_value)) - return res - @property - def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") - @property - def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}") + return res + + @property + def start_docstring(self) -> str: + return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") + + @property + def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: + return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}") diff --git a/src/openllm/utils/lazy.py b/src/openllm/utils/lazy.py index bff0a51e..5fc5972f 100644 --- a/src/openllm/utils/lazy.py +++ b/src/openllm/utils/lazy.py @@ -30,178 +30,176 @@ from ..exceptions import ForbiddenAttributeError from ..exceptions import OpenLLMException class UsageNotAllowedError(OpenLLMException): - """Raised when LazyModule.__getitem__ is forbidden.""" + """Raised when LazyModule.__getitem__ is forbidden.""" + class MissingAttributesError(OpenLLMException): - """Raised when given keys is not available in LazyModule special mapping.""" + """Raised when given keys is not available in LazyModule special mapping.""" @functools.total_ordering @attr.attrs(eq=False, order=False, slots=True, frozen=True) class VersionInfo: - """A version object that can be compared to tuple of length 1--4. + """A version object that can be compared to tuple of length 1--4. - ```python - >>> VersionInfo(19, 1, 0, "final") <= (19, 2) - True - >>> VersionInfo(19, 1, 0, "final") < (19, 1, 1) - True - >>> vi = VersionInfo(19, 2, 0, "final") - >>> vi < (19, 1, 1) - False - >>> vi < (19,) - False - >>> vi == (19, 2,) - True - >>> vi == (19, 2, 1) - False - ``` - Vendorred from attrs. + ```python + >>> VersionInfo(19, 1, 0, "final") <= (19, 2) + True + >>> VersionInfo(19, 1, 0, "final") < (19, 1, 1) + True + >>> vi = VersionInfo(19, 2, 0, "final") + >>> vi < (19, 1, 1) + False + >>> vi < (19,) + False + >>> vi == (19, 2,) + True + >>> vi == (19, 2, 1) + False + ``` + Vendorred from attrs. + """ + major: int = attr.field() + minor: int = attr.field() + micro: int = attr.field() + releaselevel: str = attr.field() + + @classmethod + def from_version_string(cls, s: str) -> VersionInfo: + """Parse *s* and return a VersionInfo.""" + v = s.split(".") + if len(v) == 3: v.append("final") + return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]) + + def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]: + """Ensure *other* is a tuple of a valid length. + + Returns a possibly transformed *other* and ourselves as a tuple of + the same length as *other*. """ - major: int = attr.field() - minor: int = attr.field() - micro: int = attr.field() - releaselevel: str = attr.field() + cmp = attr.astuple(other) if self.__class__ is other.__class__ else other + if not isinstance(cmp, tuple): raise NotImplementedError + if not (1 <= len(cmp) <= 4): raise NotImplementedError + return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[:len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp) - @classmethod - def from_version_string(cls, s: str) -> VersionInfo: - """Parse *s* and return a VersionInfo.""" - v = s.split(".") - if len(v) == 3: v.append("final") - return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]) - def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]: - """Ensure *other* is a tuple of a valid length. + def __eq__(self, other: t.Any) -> bool: + try: + us, them = self._ensure_tuple(other) + except NotImplementedError: + return NotImplemented + return us == them - Returns a possibly transformed *other* and ourselves as a tuple of - the same length as *other*. - """ - cmp = attr.astuple(other) if self.__class__ is other.__class__ else other - if not isinstance(cmp, tuple): raise NotImplementedError - if not (1 <= len(cmp) <= 4): raise NotImplementedError - return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[: len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp) - def __eq__(self, other: t.Any) -> bool: - try: us, them = self._ensure_tuple(other) - except NotImplementedError: return NotImplemented - return us == them - def __lt__(self, other: t.Any) -> bool: - try: us, them = self._ensure_tuple(other) - except NotImplementedError: return NotImplemented - # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't - # have to do anything special with releaselevel for now. - return us < them + def __lt__(self, other: t.Any) -> bool: + try: + us, them = self._ensure_tuple(other) + except NotImplementedError: + return NotImplemented + # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't + # have to do anything special with releaselevel for now. + return us < them _sentinel, _reserved_namespace = object(), {"__openllm_special__", "__openllm_migration__"} class LazyModule(types.ModuleType): - """Module class that surfaces all objects but only performs associated imports when the objects are requested. + """Module class that surfaces all objects but only performs associated imports when the objects are requested. - This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18. + This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18. - This is an extension a more powerful LazyLoader. + This is an extension a more powerful LazyLoader. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name: str, module_file: str, import_structure: dict[str, list[str]], module_spec: importlib.machinery.ModuleSpec | None = None, doc: str | None = None, extra_objects: dict[str, t.Any] | None = None,): + """Lazily load this module as an object. + + It does instantiate a __all__ and __dir__ for IDE support + + Args: + name: module name + module_file: the given file. Often default to 'globals()['__file__']' + import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module' + module_spec: __spec__ of the lazily loaded module + doc: Optional docstring for this module. + extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well + as any locals() functions """ + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module: dict[str, str] = {} + _extra_objects = {} if extra_objects is None else extra_objects + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self.__doc__ = doc + self._objects = _extra_objects + self._name = name + self._import_structure = import_structure - # Very heavily inspired by optuna.integration._IntegrationModule - # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py - def __init__( - self, - name: str, - module_file: str, - import_structure: dict[str, list[str]], - module_spec: importlib.machinery.ModuleSpec | None = None, - doc: str | None = None, - extra_objects: dict[str, t.Any] | None = None, - ): - """Lazily load this module as an object. + def __dir__(self) -> list[str]: + """Needed for autocompletion in an IDE.""" + result = t.cast("list[str]", super().__dir__()) + # The elements of self.__all__ that are submodules may or + # may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the + # elements of self.__all__ that are not already in the dir. + return result + [i for i in self.__all__ if i not in result] - It does instantiate a __all__ and __dir__ for IDE support + def __getitem__(self, key: str) -> t.Any: + """This is reserved to only internal uses and users shouldn't use this.""" + if self._objects.get("__openllm_special__") is None: raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.") + _special_mapping = self._objects.get("__openllm_special__", {}) + try: + if key in _special_mapping: return getattr(self, _special_mapping.__getitem__(key)) + raise MissingAttributesError(f"Requested '{key}' is not available in given mapping.") + except AttributeError as e: + raise KeyError(f"'{self._name}' has no attribute {_special_mapping[key]}") from e + except Exception as e: + raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e - Args: - name: module name - module_file: the given file. Often default to 'globals()['__file__']' - import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module' - module_spec: __spec__ of the lazily loaded module - doc: Optional docstring for this module. - extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well - as any locals() functions - """ - super().__init__(name) - self._modules = set(import_structure.keys()) - self._class_to_module: dict[str, str] = {} - _extra_objects = {} if extra_objects is None else extra_objects - for key, values in import_structure.items(): - for value in values: - self._class_to_module[value] = key - # Needed for autocompletion in an IDE - self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values())) - self.__file__ = module_file - self.__spec__ = module_spec - self.__path__ = [os.path.dirname(module_file)] - self.__doc__ = doc - self._objects = _extra_objects - self._name = name - self._import_structure = import_structure - def __dir__(self) -> list[str]: - """Needed for autocompletion in an IDE.""" - result = t.cast("list[str]", super().__dir__()) - # The elements of self.__all__ that are submodules may or - # may not be in the dir already, depending on whether - # they have been accessed or not. So we only add the - # elements of self.__all__ that are not already in the dir. - return result + [i for i in self.__all__ if i not in result] - def __getitem__(self, key: str) -> t.Any: - """This is reserved to only internal uses and users shouldn't use this.""" - if self._objects.get("__openllm_special__") is None: raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.") - _special_mapping = self._objects.get("__openllm_special__", {}) - try: - if key in _special_mapping: return getattr(self, _special_mapping.__getitem__(key)) - raise MissingAttributesError(f"Requested '{key}' is not available in given mapping.") - except AttributeError as e: raise KeyError(f"'{self._name}' has no attribute {_special_mapping[key]}") from e - except Exception as e: raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e - def __getattr__(self, name: str) -> t.Any: - """Equivocal __getattr__ implementation. + def __getattr__(self, name: str) -> t.Any: + """Equivocal __getattr__ implementation. - It checks from _objects > _modules and does it recursively. + It checks from _objects > _modules and does it recursively. - It also contains a special case for all of the metadata information, such as __version__ and __version_info__. - """ - if name in _reserved_namespace: raise ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.") - dunder_to_metadata = { - "__title__": "Name", - "__copyright__": "", - "__version__": "version", - "__version_info__": "version", - "__description__": "summary", - "__uri__": "", - "__url__": "", - "__author__": "", - "__email__": "", - "__license__": "license", - "__homepage__": "", - } - if name in dunder_to_metadata: - if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2) - meta = importlib.metadata.metadata("openllm") - project_url = dict(url.split(", ") for url in meta.get_all("Project-URL")) - if name == "__license__": return "Apache-2.0" - elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al." - elif name in ("__uri__", "__url__"): return project_url["GitHub"] - elif name == "__homepage__": return project_url["Homepage"] - elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__ - elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0] - elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1] - return meta[dunder_to_metadata[name]] - if "__openllm_migration__" in self._objects: - cur_value = self._objects["__openllm_migration__"].get(name, _sentinel) - if cur_value is not _sentinel: - warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3) - return getattr(self, cur_value) - if name in self._objects: return self._objects.__getitem__(name) - if name in self._modules: value = self._get_module(name) - elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name) - else: raise AttributeError(f"module {self.__name__} has no attribute {name}") - setattr(self, name, value) - return value - def _get_module(self, module_name: str) -> types.ModuleType: - try: return importlib.import_module("." + module_name, self.__name__) - except Exception as e: raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e - def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]: - """This is to ensure any given module is pickle-able.""" - return (self.__class__, (self._name, self.__file__, self._import_structure)) + It also contains a special case for all of the metadata information, such as __version__ and __version_info__. + """ + if name in _reserved_namespace: raise ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.") + dunder_to_metadata = {"__title__": "Name", "__copyright__": "", "__version__": "version", "__version_info__": "version", "__description__": "summary", "__uri__": "", "__url__": "", "__author__": "", "__email__": "", "__license__": "license", "__homepage__": "",} + if name in dunder_to_metadata: + if name not in {"__version_info__", "__copyright__", "__version__"}: + warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2) + meta = importlib.metadata.metadata("openllm") + project_url = dict(url.split(", ") for url in meta.get_all("Project-URL")) + if name == "__license__": return "Apache-2.0" + elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al." + elif name in ("__uri__", "__url__"): return project_url["GitHub"] + elif name == "__homepage__": return project_url["Homepage"] + elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__ + elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0] + elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1] + return meta[dunder_to_metadata[name]] + if "__openllm_migration__" in self._objects: + cur_value = self._objects["__openllm_migration__"].get(name, _sentinel) + if cur_value is not _sentinel: + warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3) + return getattr(self, cur_value) + if name in self._objects: return self._objects.__getitem__(name) + if name in self._modules: value = self._get_module(name) + elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name) + else: raise AttributeError(f"module {self.__name__} has no attribute {name}") + setattr(self, name, value) + return value + + def _get_module(self, module_name: str) -> types.ModuleType: + try: + return importlib.import_module("." + module_name, self.__name__) + except Exception as e: + raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e + + def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]: + """This is to ensure any given module is pickle-able.""" + return (self.__class__, (self._name, self.__file__, self._import_structure)) diff --git a/src/openllm/utils/representation.py b/src/openllm/utils/representation.py index 30210853..dfb31810 100644 --- a/src/openllm/utils/representation.py +++ b/src/openllm/utils/representation.py @@ -20,53 +20,51 @@ import attr import orjson if t.TYPE_CHECKING: - ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]] - + ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]] class ReprMixin: - """This class display possible representation of given class. + """This class display possible representation of given class. - It can be used for implementing __rich_pretty__ and __pretty__ methods in the future. - Most subclass needs to implement a __repr_keys__ property. + It can be used for implementing __rich_pretty__ and __pretty__ methods in the future. + Most subclass needs to implement a __repr_keys__ property. - Based on the design from Pydantic. - The __repr__ will display the json representation of the object for easier interaction. - The __str__ will display either __attrs_repr__ or __repr_str__. + Based on the design from Pydantic. + The __repr__ will display the json representation of the object for easier interaction. + The __str__ will display either __attrs_repr__ or __repr_str__. + """ + @property + @abstractmethod + def __repr_keys__(self) -> set[str]: + """This can be overriden by base class using this mixin.""" + + def __repr__(self) -> str: + """The `__repr__` for any subclass of Mixin. + + It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict. """ + from . import bentoml_cattr - @property - @abstractmethod - def __repr_keys__(self) -> set[str]: - """This can be overriden by base class using this mixin.""" + serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()} + return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}" - def __repr__(self) -> str: - """The `__repr__` for any subclass of Mixin. + def __str__(self) -> str: + """The string representation of the given Mixin subclass. - It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict. - """ - from . import bentoml_cattr + It will contains all of the attributes from __repr_keys__ + """ + return self.__repr_str__(" ") - serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()} - return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}" + def __repr_name__(self) -> str: + """Name of the instance's class, used in __repr__.""" + return self.__class__.__name__ - def __str__(self) -> str: - """The string representation of the given Mixin subclass. + def __repr_str__(self, join_str: str) -> str: + """To be used with __str__.""" + return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__()) - It will contains all of the attributes from __repr_keys__ - """ - return self.__repr_str__(" ") + def __repr_args__(self) -> ReprArgs: + """This can also be overriden by base class using this mixin. - def __repr_name__(self) -> str: - """Name of the instance's class, used in __repr__.""" - return self.__class__.__name__ - - def __repr_str__(self, join_str: str) -> str: - """To be used with __str__.""" - return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__()) - - def __repr_args__(self) -> ReprArgs: - """This can also be overriden by base class using this mixin. - - By default it does a getattr of the current object from __repr_keys__. - """ - return ((k, getattr(self, k)) for k in self.__repr_keys__) + By default it does a getattr of the current object from __repr_keys__. + """ + return ((k, getattr(self, k)) for k in self.__repr_keys__) diff --git a/src/openllm_client/__init__.py b/src/openllm_client/__init__.py index 3f1e7f20..35d995c3 100644 --- a/src/openllm_client/__init__.py +++ b/src/openllm_client/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """The actual client implementation. Use ``openllm.client`` instead. diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index dc18eba0..6151c272 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -28,240 +28,316 @@ import openllm # NOTE: We need to do this so that overload can register # correct overloads to typing registry if sys.version_info[:2] >= (3, 11): - from typing import overload + from typing import overload else: - from typing_extensions import overload + from typing_extensions import overload if t.TYPE_CHECKING: - import transformers - from openllm._types import DictStrAny - from openllm._types import LiteralRuntime - class AnnotatedClient(bentoml.client.Client): - def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ... - async def async_health(self) -> t.Any: ... - def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: ... - def metadata_v1(self) -> dict[str, t.Any]: ... - def embeddings_v1(self) -> t.Sequence[float]: ... -else: transformers, DictStrAny = openllm.utils.LazyLoader("transformers", globals(), "transformers"), dict + import transformers + from openllm._types import DictStrAny + from openllm._types import LiteralRuntime + + class AnnotatedClient(bentoml.client.Client): + def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: + ... + + async def async_health(self) -> t.Any: + ... + + def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: + ... + + def metadata_v1(self) -> dict[str, t.Any]: + ... + + def embeddings_v1(self) -> t.Sequence[float]: + ... +else: + + transformers, DictStrAny = openllm.utils.LazyLoader("transformers", globals(), "transformers"), dict logger = logging.getLogger(__name__) def in_async_context() -> bool: - try: - _ = asyncio.get_running_loop() - return True - except RuntimeError: return False + try: + _ = asyncio.get_running_loop() + return True + except RuntimeError: + return False T = t.TypeVar("T") class ClientMeta(t.Generic[T]): - _api_version: str - _client_class: type[bentoml.client.Client] - _host: str - _port: str + _api_version: str + _client_class: type[bentoml.client.Client] + _host: str + _port: str - __client__: AnnotatedClient | None = None - __agent__: transformers.HfAgent | None = None - __llm__: openllm.LLM[t.Any, t.Any] | None = None + __client__: AnnotatedClient | None = None + __agent__: transformers.HfAgent | None = None + __llm__: openllm.LLM[t.Any, t.Any] | None = None - def __init__(self, address: str, timeout: int = 30): - self._address = address - self._timeout = timeout - def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): - """Initialise subclass for HTTP and gRPC client type.""" - cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient - cls._api_version = api_version - @property - def _hf_agent(self) -> transformers.HfAgent: - if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.") - if self.__agent__ is None: - if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'") - self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent")) - return self.__agent__ - @property - def _metadata(self) -> T: - if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() - return self.call("metadata") - @property - @abstractmethod - def model_name(self) -> str: raise NotImplementedError - @property - @abstractmethod - def framework(self) -> LiteralRuntime: raise NotImplementedError - @property - @abstractmethod - def timeout(self) -> int: raise NotImplementedError - @property - @abstractmethod - def model_id(self) -> str: raise NotImplementedError - @property - @abstractmethod - def configuration(self) -> dict[str, t.Any]: raise NotImplementedError - @property - @abstractmethod - def supports_embeddings(self) -> bool: raise NotImplementedError - @property - @abstractmethod - def supports_hf_agent(self) -> bool: raise NotImplementedError - @property - def llm(self) -> openllm.LLM[t.Any, t.Any]: - if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name) - return self.__llm__ - @property - def config(self) -> openllm.LLMConfig: return self.llm.config - def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs) - async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs) - @property - def _cached(self) -> AnnotatedClient: - if self.__client__ is None: - self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) - self.__client__ = t.cast("AnnotatedClient", self._client_class.from_url(self._address)) - return self.__client__ - @abstractmethod - def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ... - @abstractmethod - def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ... + def __init__(self, address: str, timeout: int = 30): + self._address = address + self._timeout = timeout + + def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): + """Initialise subclass for HTTP and gRPC client type.""" + cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient + cls._api_version = api_version + + @property + def _hf_agent(self) -> transformers.HfAgent: + if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.") + if self.__agent__ is None: + if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'") + self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent")) + return self.__agent__ + + @property + def _metadata(self) -> T: + if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() + return self.call("metadata") + + @property + @abstractmethod + def model_name(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def framework(self) -> LiteralRuntime: + raise NotImplementedError + + @property + @abstractmethod + def timeout(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def model_id(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def configuration(self) -> dict[str, t.Any]: + raise NotImplementedError + + @property + @abstractmethod + def supports_embeddings(self) -> bool: + raise NotImplementedError + + @property + @abstractmethod + def supports_hf_agent(self) -> bool: + raise NotImplementedError + + @property + def llm(self) -> openllm.LLM[t.Any, t.Any]: + if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name) + return self.__llm__ + + @property + def config(self) -> openllm.LLMConfig: + return self.llm.config + + def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: + return self._cached.call(f"{name}_{self._api_version}", *args, **attrs) + + async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: + return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs) + + @property + def _cached(self) -> AnnotatedClient: + if self.__client__ is None: + self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) + self.__client__ = t.cast("AnnotatedClient", self._client_class.from_url(self._address)) + return self.__client__ + + @abstractmethod + def postprocess(self, result: t.Any) -> openllm.GenerationOutput: + ... + + @abstractmethod + def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + ... class BaseClient(ClientMeta[T]): - def health(self) -> t.Any: raise NotImplementedError - def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError - def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError - @overload - def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... - @overload - def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... - @overload - def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... - def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: - return_raw_response = attrs.pop("return_raw_response", None) - if return_raw_response is not None: - logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.") - if return_raw_response is True: return_response = "raw" - return_attrs = attrs.pop("return_attrs", None) - if return_attrs is not None: - logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.") - if return_attrs is True: return_response = "attrs" - use_default_prompt_template = attrs.pop("use_default_prompt_template", False) - prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs) + def health(self) -> t.Any: + raise NotImplementedError - inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) - if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json() - else: result = self.call("generate", inputs.model_dump()) - r = self.postprocess(result) - if return_response == "attrs": return r - elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r) - else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) - # NOTE: Scikit interface - @overload - def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... - @overload - def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... - @overload - def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... - def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs)) - def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: - if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) - else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") - def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: - if len(args) > 1: raise ValueError("'args' should only take one positional argument.") - task = kwargs.pop("task", args[0]) - return_code = kwargs.pop("return_code", False) - remote = kwargs.pop("remote", False) - try: - return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs) - except Exception as err: - logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err) - logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address) + def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: + raise NotImplementedError + def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: + raise NotImplementedError + + @overload + def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: + ... + + @overload + def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: + ... + + @overload + def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: + ... + + def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: + return_raw_response = attrs.pop("return_raw_response", None) + if return_raw_response is not None: + logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.") + if return_raw_response is True: return_response = "raw" + return_attrs = attrs.pop("return_attrs", None) + if return_attrs is not None: + logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.") + if return_attrs is True: return_response = "attrs" + use_default_prompt_template = attrs.pop("use_default_prompt_template", False) + prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs) + + inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) + if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json() + else: result = self.call("generate", inputs.model_dump()) + r = self.postprocess(result) + if return_response == "attrs": return r + elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r) + else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) + + # NOTE: Scikit interface + @overload + def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: + ... + + @overload + def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: + ... + + @overload + def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: + ... + + def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: + return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs)) + + def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: + if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) + else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") + + def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + if len(args) > 1: raise ValueError("'args' should only take one positional argument.") + task = kwargs.pop("task", args[0]) + return_code = kwargs.pop("return_code", False) + remote = kwargs.pop("remote", False) + try: + return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs) + except Exception as err: + logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err) + logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address) class BaseAsyncClient(ClientMeta[T]): - async def health(self) -> t.Any: raise NotImplementedError - async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError - async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError - @overload - async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... - @overload - async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... - @overload - async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... - async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: - return_raw_response = attrs.pop("return_raw_response", None) - if return_raw_response is not None: - logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.") - if return_raw_response is True: return_response = "raw" - return_attrs = attrs.pop("return_attrs", None) - if return_attrs is not None: - logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.") - if return_attrs is True: return_response = "attrs" - use_default_prompt_template = attrs.pop("use_default_prompt_template", False) - prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs) + async def health(self) -> t.Any: + raise NotImplementedError - inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) - res = await self.acall("generate", inputs.model_dump()) - r = self.postprocess(res) + async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: + raise NotImplementedError - if return_response == "attrs": return r - elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r) - else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) - # NOTE: Scikit interface - @overload - async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... - @overload - async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... - @overload - async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... - async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs)) + async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: + raise NotImplementedError - async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: - """Async version of agent.run.""" - if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) - else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") - async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: - if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0") - if len(args) > 1: raise ValueError("'args' should only take one positional argument.") - task = kwargs.pop("task", args[0]) - return_code = kwargs.pop("return_code", False) - remote = kwargs.pop("remote", False) + @overload + async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: + ... - from transformers.tools.agents import clean_code_for_run - from transformers.tools.agents import get_tool_creation_code - from transformers.tools.agents import resolve_tools - from transformers.tools.python_interpreter import evaluate + @overload + async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: + ... - _hf_agent = self._hf_agent + @overload + async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: + ... - prompt = t.cast(str, _hf_agent.format_prompt(task)) - stop = ["Task:"] - async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client: - response = await client.post( - _hf_agent.url_endpoint, - json={ - "inputs": prompt, - "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop}, - }, - ) - if response.status_code != HTTPStatus.OK: - raise ValueError(f"Error {response.status_code}: {response.json()}") + async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: + return_raw_response = attrs.pop("return_raw_response", None) + if return_raw_response is not None: + logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.") + if return_raw_response is True: return_response = "raw" + return_attrs = attrs.pop("return_attrs", None) + if return_attrs is not None: + logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.") + if return_attrs is True: return_response = "attrs" + use_default_prompt_template = attrs.pop("use_default_prompt_template", False) + prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs) - result = response.json()[0]["generated_text"] - # Inference API returns the stop sequence - for stop_seq in stop: - if result.endswith(stop_seq): - result = result[: -len(stop_seq)] - break + inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) + res = await self.acall("generate", inputs.model_dump()) + r = self.postprocess(res) - # the below have the same logic as agent.run API - explanation, code = clean_code_for_run(result) + if return_response == "attrs": return r + elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r) + else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) - _hf_agent.log(f"==Explanation from the agent==\n{explanation}") + # NOTE: Scikit interface + @overload + async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: + ... - _hf_agent.log(f"\n\n==Code generated by the agent==\n{code}") - if not return_code: - _hf_agent.log("\n\n==Result==") - _hf_agent.cached_tools = resolve_tools( - code, _hf_agent.toolbox, remote=remote, cached_tools=_hf_agent.cached_tools - ) - return evaluate(code, _hf_agent.cached_tools, state=kwargs.copy()) - else: - tool_code = get_tool_creation_code(code, _hf_agent.toolbox, remote=remote) - return f"{tool_code}\n{code}" + @overload + async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: + ... + + @overload + async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: + ... + + async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: + return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs)) + + async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: + """Async version of agent.run.""" + if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) + else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") + + async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0") + if len(args) > 1: raise ValueError("'args' should only take one positional argument.") + task = kwargs.pop("task", args[0]) + return_code = kwargs.pop("return_code", False) + remote = kwargs.pop("remote", False) + + from transformers.tools.agents import clean_code_for_run + from transformers.tools.agents import get_tool_creation_code + from transformers.tools.agents import resolve_tools + from transformers.tools.python_interpreter import evaluate + + _hf_agent = self._hf_agent + + prompt = t.cast(str, _hf_agent.format_prompt(task)) + stop = ["Task:"] + async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client: + response = await client.post(_hf_agent.url_endpoint, json={"inputs": prompt, "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},},) + if response.status_code != HTTPStatus.OK: + raise ValueError(f"Error {response.status_code}: {response.json()}") + + result = response.json()[0]["generated_text"] + # Inference API returns the stop sequence + for stop_seq in stop: + if result.endswith(stop_seq): + result = result[:-len(stop_seq)] + break + + # the below have the same logic as agent.run API + explanation, code = clean_code_for_run(result) + + _hf_agent.log(f"==Explanation from the agent==\n{explanation}") + + _hf_agent.log(f"\n\n==Code generated by the agent==\n{code}") + if not return_code: + _hf_agent.log("\n\n==Result==") + _hf_agent.cached_tools = resolve_tools(code, _hf_agent.toolbox, remote=remote, cached_tools=_hf_agent.cached_tools) + return evaluate(code, _hf_agent.cached_tools, state=kwargs.copy()) + else: + tool_code = get_tool_creation_code(code, _hf_agent.toolbox, remote=remote) + return f"{tool_code}\n{code}" diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py index a859c5d6..7e913f79 100644 --- a/src/openllm_client/runtimes/grpc.py +++ b/src/openllm_client/runtimes/grpc.py @@ -25,96 +25,93 @@ from .base import BaseAsyncClient from .base import BaseClient if t.TYPE_CHECKING: - from grpc_health.v1 import health_pb2 + from grpc_health.v1 import health_pb2 - from bentoml.grpc.v1.service_pb2 import Response - from openllm._types import LiteralRuntime + from bentoml.grpc.v1.service_pb2 import Response + from openllm._types import LiteralRuntime logger = logging.getLogger(__name__) - class GrpcClientMixin: - if t.TYPE_CHECKING: - - @property - def _metadata(self) -> Response: - ... + if t.TYPE_CHECKING: @property - def model_name(self) -> str: - try: - return self._metadata.json.struct_value.fields["model_name"].string_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + def _metadata(self) -> Response: + ... - @property - def framework(self) -> LiteralRuntime: - try: - value = self._metadata.json.struct_value.fields["framework"].string_value - if value not in ("pt", "flax", "tf"): - raise KeyError - return value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def model_name(self) -> str: + try: + return self._metadata.json.struct_value.fields["model_name"].string_value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def timeout(self) -> int: - try: - return int(self._metadata.json.struct_value.fields["timeout"].number_value) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def framework(self) -> LiteralRuntime: + try: + value = self._metadata.json.struct_value.fields["framework"].string_value + if value not in ("pt", "flax", "tf"): + raise KeyError + return value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def model_id(self) -> str: - try: - return self._metadata.json.struct_value.fields["model_id"].string_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def timeout(self) -> int: + try: + return int(self._metadata.json.struct_value.fields["timeout"].number_value) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def configuration(self) -> dict[str, t.Any]: - try: - v = self._metadata.json.struct_value.fields["configuration"].string_value - return orjson.loads(v) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def model_id(self) -> str: + try: + return self._metadata.json.struct_value.fields["model_id"].string_value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def supports_embeddings(self) -> bool: - try: - return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def configuration(self) -> dict[str, t.Any]: + try: + v = self._metadata.json.struct_value.fields["configuration"].string_value + return orjson.loads(v) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def supports_hf_agent(self) -> bool: - try: - return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def supports_embeddings(self) -> bool: + try: + return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput: - if isinstance(result, dict): - return openllm.GenerationOutput(**result) + @property + def supports_hf_agent(self) -> bool: + try: + return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - from google.protobuf.json_format import MessageToDict + def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput: + if isinstance(result, dict): + return openllm.GenerationOutput(**result) - return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True)) + from google.protobuf.json_format import MessageToDict + return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True)) class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"): - def __init__(self, address: str, timeout: int = 30): - self._host, self._port = address.split(":") - super().__init__(address, timeout) - - def health(self) -> health_pb2.HealthCheckResponse: - return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) + def __init__(self, address: str, timeout: int = 30): + self._host, self._port = address.split(":") + super().__init__(address, timeout) + def health(self) -> health_pb2.HealthCheckResponse: + return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"): - def __init__(self, address: str, timeout: int = 30): - self._host, self._port = address.split(":") - super().__init__(address, timeout) + def __init__(self, address: str, timeout: int = 30): + self._host, self._port = address.split(":") + super().__init__(address, timeout) - async def health(self) -> health_pb2.HealthCheckResponse: - return await self._cached.health("bentoml.grpc.v1.BentoService") + async def health(self) -> health_pb2.HealthCheckResponse: + return await self._cached.health("bentoml.grpc.v1.BentoService") diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py index f3e31919..78cf078b 100644 --- a/src/openllm_client/runtimes/http.py +++ b/src/openllm_client/runtimes/http.py @@ -28,71 +28,101 @@ from .base import BaseClient from .base import in_async_context if t.TYPE_CHECKING: - from openllm._types import DictStrAny - from openllm._types import LiteralRuntime + from openllm._types import DictStrAny + from openllm._types import LiteralRuntime else: - DictStrAny = dict + DictStrAny = dict logger = logging.getLogger(__name__) class HTTPClientMixin: - if t.TYPE_CHECKING: - @property - def _metadata(self) -> DictStrAny: ... + if t.TYPE_CHECKING: + @property - def model_name(self) -> str: - try: return self._metadata["model_name"] - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def model_id(self) -> str: - try: return self._metadata["model_name"] - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def framework(self) -> LiteralRuntime: - try: return self._metadata["framework"] - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def timeout(self) -> int: - try: return self._metadata["timeout"] - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def configuration(self) -> dict[str, t.Any]: - try: return orjson.loads(self._metadata["configuration"]) - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def supports_embeddings(self) -> bool: - try: return self._metadata.get("supports_embeddings", False) - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - @property - def supports_hf_agent(self) -> bool: - try: return self._metadata.get("supports_hf_agent", False) - except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: - return openllm.GenerationOutput(**result) + def _metadata(self) -> DictStrAny: + ... + + @property + def model_name(self) -> str: + try: + return self._metadata["model_name"] + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + @property + def model_id(self) -> str: + try: + return self._metadata["model_name"] + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + @property + def framework(self) -> LiteralRuntime: + try: + return self._metadata["framework"] + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + @property + def timeout(self) -> int: + try: + return self._metadata["timeout"] + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + @property + def configuration(self) -> dict[str, t.Any]: + try: + return orjson.loads(self._metadata["configuration"]) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + @property + def supports_embeddings(self) -> bool: + try: + return self._metadata.get("supports_embeddings", False) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + @property + def supports_hf_agent(self) -> bool: + try: + return self._metadata.get("supports_hf_agent", False) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + + def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: + return openllm.GenerationOutput(**result) class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]): - def __init__(self, address: str, timeout: int = 30): - address = address if "://" in address else "http://" + address - self._host, self._port = urlparse(address).netloc.split(":") - super().__init__(address, timeout) - def health(self) -> t.Any: return self._cached.health() - def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: - if not self.supports_embeddings: - raise ValueError("This model does not support embeddings.") - if isinstance(prompt, str): prompt = [prompt] - if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout) - else: result = self.call("embeddings", list(prompt)) - return openllm.EmbeddingsOutput(**result) + def __init__(self, address: str, timeout: int = 30): + address = address if "://" in address else "http://" + address + self._host, self._port = urlparse(address).netloc.split(":") + super().__init__(address, timeout) + + def health(self) -> t.Any: + return self._cached.health() + + def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: + if not self.supports_embeddings: + raise ValueError("This model does not support embeddings.") + if isinstance(prompt, str): prompt = [prompt] + if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout) + else: result = self.call("embeddings", list(prompt)) + return openllm.EmbeddingsOutput(**result) class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]): - def __init__(self, address: str, timeout: int = 30): - address = address if "://" in address else "http://" + address - self._host, self._port = urlparse(address).netloc.split(":") - super().__init__(address, timeout) - async def health(self) -> t.Any: return await self._cached.async_health() - async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: - if not self.supports_embeddings: - raise ValueError("This model does not support embeddings.") - if isinstance(prompt, str): prompt = [prompt] - res = await self.acall("embeddings", list(prompt)) - return openllm.EmbeddingsOutput(**res) + def __init__(self, address: str, timeout: int = 30): + address = address if "://" in address else "http://" + address + self._host, self._port = urlparse(address).netloc.split(":") + super().__init__(address, timeout) + + async def health(self) -> t.Any: + return await self._cached.async_health() + + async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: + if not self.supports_embeddings: + raise ValueError("This model does not support embeddings.") + if isinstance(prompt, str): prompt = [prompt] + res = await self.acall("embeddings", list(prompt)) + return openllm.EmbeddingsOutput(**res) diff --git a/tests/__init__.py b/tests/__init__.py index 5e960cb1..62114cc8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,5 +18,4 @@ from hypothesis import settings settings.register_profile("CI", settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None) -if "CI" in os.environ: - settings.load_profile("CI") +if "CI" in os.environ: settings.load_profile("CI") diff --git a/tests/_strategies/_configuration.py b/tests/_strategies/_configuration.py index ab574f69..5cc51585 100644 --- a/tests/_strategies/_configuration.py +++ b/tests/_strategies/_configuration.py @@ -25,55 +25,39 @@ logger = logging.getLogger(__name__) env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_name in openllm.CONFIG_MAPPING.keys()]) - @st.composite def model_settings(draw: st.DrawFn): - """Strategy for generating ModelSettings objects.""" - kwargs: dict[str, t.Any] = { - "default_id": st.text(min_size=1), - "model_ids": st.lists(st.text(), min_size=1), - "architecture": st.text(min_size=1), - "url": st.text(), - "requires_gpu": st.booleans(), - "trust_remote_code": st.booleans(), - "requirements": st.none() | st.lists(st.text(), min_size=1), - "default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])), - "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]), - "runtime": st.sampled_from(["transformers", "ggml"]), - "name_type": st.sampled_from(["dasherize", "lowercase"]), - "timeout": st.integers(min_value=3600), - "workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)), - } - return draw(st.builds(ModelSettings, **kwargs)) + """Strategy for generating ModelSettings objects.""" + kwargs: dict[str, t.Any] = { + "default_id": st.text(min_size=1), "model_ids": st.lists(st.text(), min_size=1), "architecture": st.text(min_size=1), "url": st.text(), "requires_gpu": st.booleans(), "trust_remote_code": st.booleans(), "requirements": st.none() + | st.lists(st.text(), min_size=1), "default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])), "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]), "runtime": st.sampled_from(["transformers", "ggml"]), "name_type": st.sampled_from(["dasherize", "lowercase"]), "timeout": st.integers( + min_value=3600 + ), "workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)), + } + return draw(st.builds(ModelSettings, **kwargs)) +def make_llm_config(cls_name: str, dunder_config: dict[str, t.Any] | ModelSettings, fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,) -> type[openllm.LLMConfig]: + globs: dict[str, t.Any] = {"openllm": openllm} + _config_args: list[str] = [] + lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"] + for attr, value in dunder_config.items(): + _config_args.append(f'"{attr}": __attr_{attr}') + globs[f"_{cls_name}Config__attr_{attr}"] = value + lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') + if fields is not None: + for field, type_, default in fields: + lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})") + if generation_fields is not None: + generation_lines = ["class GenerationConfig:"] + for field, default in generation_fields: + generation_lines.append(f" {field} = {default!r}") + lines.extend((" " + line for line in generation_lines)) -def make_llm_config( - cls_name: str, - dunder_config: dict[str, t.Any] | ModelSettings, - fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, - generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None, -) -> type[openllm.LLMConfig]: - globs: dict[str, t.Any] = {"openllm": openllm} - _config_args: list[str] = [] - lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"] - for attr, value in dunder_config.items(): - _config_args.append(f'"{attr}": __attr_{attr}') - globs[f"_{cls_name}Config__attr_{attr}"] = value - lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') - if fields is not None: - for field, type_, default in fields: - lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})") - if generation_fields is not None: - generation_lines = ["class GenerationConfig:"] - for field, default in generation_fields: - generation_lines.append(f" {field} = {default!r}") - lines.extend((" " + line for line in generation_lines)) + script = "\n".join(lines) - script = "\n".join(lines) + if openllm.utils.DEBUG: + logger.info("Generated class %s:\n%s", cls_name, script) - if openllm.utils.DEBUG: - logger.info("Generated class %s:\n%s", cls_name, script) + eval(compile(script, "name", "exec"), globs) - eval(compile(script, "name", "exec"), globs) - - return globs[f"{cls_name}Config"] + return globs[f"{cls_name}Config"] diff --git a/tests/client_test.py b/tests/client_test.py index 88650445..df31b722 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -17,7 +17,5 @@ from __future__ import annotations import openllm def test_import_client(): - assert len(openllm.client.__all__) == 4 - assert all( - hasattr(openllm.client, attr) for attr in ("AsyncGrpcClient", "GrpcClient", "AsyncHTTPClient", "HTTPClient") - ) + assert len(openllm.client.__all__) == 4 + assert all(hasattr(openllm.client, attr) for attr in ("AsyncGrpcClient", "GrpcClient", "AsyncHTTPClient", "HTTPClient")) diff --git a/tests/configuration_test.py b/tests/configuration_test.py index 6a937b9c..5b525009 100644 --- a/tests/configuration_test.py +++ b/tests/configuration_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """All configuration-related tests for openllm.LLMConfig. This will include testing for ModelEnv construction and parsing environment variables. """ @@ -41,213 +40,137 @@ from ._strategies._configuration import model_settings logger = logging.getLogger(__name__) if t.TYPE_CHECKING: - DictStrAny = dict[str, t.Any] + DictStrAny = dict[str, t.Any] else: - DictStrAny = dict - + DictStrAny = dict # XXX: @aarnphm fixes TypedDict behaviour in 3.11 -@pytest.mark.skipif( - sys.version_info[:2] == (3, 11), reason="TypedDict in 3.11 behaves differently, so we need to fix this" -) +@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason="TypedDict in 3.11 behaves differently, so we need to fix this") def test_missing_default(): - with pytest.raises(ValueError, match="Missing required fields *"): - make_llm_config("MissingDefaultId", {"name_type": "lowercase", "requirements": ["bentoml"]}) - with pytest.raises(ValueError, match="Missing required fields *"): - make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]}) - with pytest.raises(ValueError, match="Missing required fields *"): - make_llm_config( - "MissingArchitecture", - { - "default_id": "huggingface/t5-tiny-testing", - "model_ids": ["huggingface/t5-tiny-testing"], - "requirements": ["bentoml"], - }, - ) - + with pytest.raises(ValueError, match="Missing required fields *"): + make_llm_config("MissingDefaultId", {"name_type": "lowercase", "requirements": ["bentoml"]}) + with pytest.raises(ValueError, match="Missing required fields *"): + make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]}) + with pytest.raises(ValueError, match="Missing required fields *"): + make_llm_config("MissingArchitecture", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing"], "requirements": ["bentoml"],},) def test_forbidden_access(): - cl_ = make_llm_config( - "ForbiddenAccess", - { - "default_id": "huggingface/t5-tiny-testing", - "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"], - "architecture": "PreTrainedModel", - "requirements": ["bentoml"], - }, - ) + cl_ = make_llm_config("ForbiddenAccess", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"], "architecture": "PreTrainedModel", "requirements": ["bentoml"],},) - assert pytest.raises( - openllm.exceptions.ForbiddenAttributeError, - cl_.__getattribute__, - cl_(), - "__config__", - ) - assert pytest.raises( - openllm.exceptions.ForbiddenAttributeError, - cl_.__getattribute__, - cl_(), - "GenerationConfig", - ) - assert pytest.raises( - openllm.exceptions.ForbiddenAttributeError, - cl_.__getattribute__, - cl_(), - "SamplingParams", - ) - - assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig) + assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "__config__",) + assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "GenerationConfig",) + assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "SamplingParams",) + assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig) @given(model_settings()) def test_class_normal_gen(gen_settings: ModelSettings): - assume(gen_settings["default_id"] and all(i for i in gen_settings["model_ids"])) - cl_: type[openllm.LLMConfig] = make_llm_config("NotFullLLM", gen_settings) - assert issubclass(cl_, openllm.LLMConfig) - for key in gen_settings: - assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key) - + assume(gen_settings["default_id"] and all(i for i in gen_settings["model_ids"])) + cl_: type[openllm.LLMConfig] = make_llm_config("NotFullLLM", gen_settings) + assert issubclass(cl_, openllm.LLMConfig) + for key in gen_settings: + assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key) @given(model_settings(), st.integers()) def test_simple_struct_dump(gen_settings: ModelSettings, field1: int): - cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),)) - assert cl_().model_dump()["field1"] == field1 - + cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),)) + assert cl_().model_dump()["field1"] == field1 @given(model_settings(), st.integers()) def test_config_derivation(gen_settings: ModelSettings, field1: int): - cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),)) - new_cls = cl_.model_derivate("DerivedLLM", default_id="asdfasdf") - assert new_cls.__openllm_default_id__ == "asdfasdf" - + cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),)) + new_cls = cl_.model_derivate("DerivedLLM", default_id="asdfasdf") + assert new_cls.__openllm_default_id__ == "asdfasdf" @given(model_settings()) def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings): - cl_ = make_llm_config("AttrsProtocolLLM", gen_settings) - assert attr.has(cl_) + cl_ = make_llm_config("AttrsProtocolLLM", gen_settings) + assert attr.has(cl_) +@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),) +def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float): + cl_ = make_llm_config("ComplexLLM", gen_settings, fields=(("field1", "float", field1),), generation_fields=(("temperature", temperature),),) + sent = cl_() + assert sent.model_dump()["field1"] == field1 + assert sent.model_dump()["generation_config"]["temperature"] == temperature + assert sent.model_dump(flatten=True)["field1"] == field1 + assert sent.model_dump(flatten=True)["temperature"] == temperature -@given( - model_settings(), - st.integers(max_value=283473), - st.floats(min_value=0.0, max_value=1.0), - st.integers(max_value=283473), - st.floats(min_value=0.0, max_value=1.0), -) -def test_complex_struct_dump( - gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float -): - cl_ = make_llm_config( - "ComplexLLM", - gen_settings, - fields=(("field1", "float", field1),), - generation_fields=(("temperature", temperature),), - ) - sent = cl_() - assert sent.model_dump()["field1"] == field1 - assert sent.model_dump()["generation_config"]["temperature"] == temperature - assert sent.model_dump(flatten=True)["field1"] == field1 - assert sent.model_dump(flatten=True)["temperature"] == temperature - - passed = cl_(field1=input_field1, temperature=input_temperature) - assert passed.model_dump()["field1"] == input_field1 - assert passed.model_dump()["generation_config"]["temperature"] == input_temperature - assert passed.model_dump(flatten=True)["field1"] == input_field1 - assert passed.model_dump(flatten=True)["temperature"] == input_temperature - - pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1) - assert pas_nested.model_dump()["field1"] == input_field1 - assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature + passed = cl_(field1=input_field1, temperature=input_temperature) + assert passed.model_dump()["field1"] == input_field1 + assert passed.model_dump()["generation_config"]["temperature"] == input_temperature + assert passed.model_dump(flatten=True)["field1"] == input_field1 + assert passed.model_dump(flatten=True)["temperature"] == input_temperature + pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1) + assert pas_nested.model_dump()["field1"] == input_field1 + assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature @contextlib.contextmanager def patch_env(**attrs: t.Any): - with mock.patch.dict(os.environ, attrs, clear=True): - yield - + with mock.patch.dict(os.environ, attrs, clear=True): + yield def test_struct_envvar(): - with patch_env( - **{ - field_env_key("env_llm", "field1"): "4", - field_env_key("env_llm", "temperature", suffix="generation"): "0.2", - } - ): + with patch_env(**{field_env_key("env_llm", "field1"): "4", field_env_key("env_llm", "temperature", suffix="generation"): "0.2",}): - class EnvLLM(openllm.LLMConfig): - __config__ = { - "default_id": "asdfasdf", - "model_ids": ["asdf", "asdfasdfads"], - "architecture": "PreTrainedModel", - } - field1: int = 2 + class EnvLLM(openllm.LLMConfig): + __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",} + field1: int = 2 - class GenerationConfig: - temperature: float = 0.8 + class GenerationConfig: + temperature: float = 0.8 - sent = EnvLLM.model_construct_env() - assert sent.field1 == 4 - assert sent["temperature"] == 0.2 - - overwrite_default = EnvLLM() - assert overwrite_default.field1 == 4 - assert overwrite_default["temperature"] == 0.2 + sent = EnvLLM.model_construct_env() + assert sent.field1 == 4 + assert sent["temperature"] == 0.2 + overwrite_default = EnvLLM() + assert overwrite_default.field1 == 4 + assert overwrite_default["temperature"] == 0.2 def test_struct_provided_fields(): - class EnvLLM(openllm.LLMConfig): - __config__ = { - "default_id": "asdfasdf", - "model_ids": ["asdf", "asdfasdfads"], - "architecture": "PreTrainedModel", - } - field1: int = 2 + class EnvLLM(openllm.LLMConfig): + __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",} + field1: int = 2 - class GenerationConfig: - temperature: float = 0.8 - - sent = EnvLLM.model_construct_env(field1=20, temperature=0.4) - assert sent.field1 == 20 - assert sent.generation_config.temperature == 0.4 + class GenerationConfig: + temperature: float = 0.8 + sent = EnvLLM.model_construct_env(field1=20, temperature=0.4) + assert sent.field1 == 20 + assert sent.generation_config.temperature == 0.4 def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mk: - mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0)) - mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2)) - sent = make_llm_config( - "OverwriteWithEnvAvailable", - {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel"}, - fields=(("field1", "float", 3.0),), - ).model_construct_env(field1=20.0, temperature=0.4) - assert sent.generation_config.temperature == 0.4 - assert sent.field1 == 20.0 - + with monkeypatch.context() as mk: + mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0)) + mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2)) + sent = make_llm_config("OverwriteWithEnvAvailable", {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel"}, fields=(("field1", "float", 3.0),),).model_construct_env(field1=20.0, temperature=0.4) + assert sent.generation_config.temperature == 0.4 + assert sent.field1 == 20.0 @given(model_settings()) @pytest.mark.parametrize(("return_dict", "typ"), [(True, DictStrAny), (False, transformers.GenerationConfig)]) def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings): - cl_ = make_llm_config("ConversionLLM", gen_settings) - assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ) - + cl_ = make_llm_config("ConversionLLM", gen_settings) + assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ) @given(model_settings()) def test_click_conversion(gen_settings: ModelSettings): - # currently our conversion omit Union type. - def cli_mock(**attrs: t.Any): - return attrs - - cl_ = make_llm_config("ClickConversionLLM", gen_settings) - wrapped = cl_.to_click_options(cli_mock) - filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union} - click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")] - assert len(filtered) == len(click_options_filtered) + # currently our conversion omit Union type. + def cli_mock(**attrs: t.Any): + return attrs + cl_ = make_llm_config("ClickConversionLLM", gen_settings) + wrapped = cl_.to_click_options(cli_mock) + filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union} + click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")] + assert len(filtered) == len(click_options_filtered) @pytest.mark.parametrize("model_name", openllm.CONFIG_MAPPING.keys()) def test_configuration_dict_protocol(model_name: str): - config = openllm.AutoConfig.for_model(model_name) - assert isinstance(config.items(), list) - assert isinstance(config.keys(), list) - assert isinstance(config.values(), list) - assert isinstance(dict(config), dict) + config = openllm.AutoConfig.for_model(model_name) + assert isinstance(config.items(), list) + assert isinstance(config.keys(), list) + assert isinstance(config.values(), list) + assert isinstance(dict(config), dict) diff --git a/tests/conftest.py b/tests/conftest.py index 32f79858..d2d96c3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,53 +22,33 @@ import pytest import openllm if t.TYPE_CHECKING: - from openllm._types import LiteralRuntime + from openllm._types import LiteralRuntime +_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m", "baichuan": "baichuan-inc/Baichuan-7B",} +_PROMPT_MAPPING = {"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",} -_FRAMEWORK_MAPPING = { - "flan_t5": "google/flan-t5-small", - "opt": "facebook/opt-125m", - "baichuan": "baichuan-inc/Baichuan-7B", -} -_PROMPT_MAPPING = { - "qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?", -} +def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]: + if model not in _FRAMEWORK_MAPPING: + pytest.skip(f"'{model}' is not yet supported in framework testing.") + runtime_impl: tuple[LiteralRuntime, ...] = tuple() + if model in openllm.MODEL_MAPPING_NAMES: + runtime_impl += ("pt",) + if model in openllm.MODEL_FLAX_MAPPING_NAMES: + runtime_impl += ("flax",) + if model in openllm.MODEL_TF_MAPPING_NAMES: + runtime_impl += ("tf",) -def parametrise_local_llm( - model: str, -) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]: - if model not in _FRAMEWORK_MAPPING: - pytest.skip(f"'{model}' is not yet supported in framework testing.") - - runtime_impl: tuple[LiteralRuntime, ...] = tuple() - if model in openllm.MODEL_MAPPING_NAMES: - runtime_impl += ("pt",) - if model in openllm.MODEL_FLAX_MAPPING_NAMES: - runtime_impl += ("flax",) - if model in openllm.MODEL_TF_MAPPING_NAMES: - runtime_impl += ("tf",) - - for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()): - llm = openllm.Runner( - model, - model_id=_FRAMEWORK_MAPPING[model], - ensure_available=True, - implementation=framework, - init_local=True, - ) - yield prompt, llm - + for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()): + llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,) + yield prompt, llm def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: - if os.getenv("GITHUB_ACTIONS") is None: - if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames: - metafunc.parametrize( - "prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])] - ) - + if os.getenv("GITHUB_ACTIONS") is None: + if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames: + metafunc.parametrize("prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]) def pytest_sessionfinish(session: pytest.Session, exitstatus: int): - # If no tests are collected, pytest exists with code 5, which makes the CI fail. - if exitstatus == 5: - session.exitstatus = 0 + # If no tests are collected, pytest exists with code 5, which makes the CI fail. + if exitstatus == 5: + session.exitstatus = 0 diff --git a/tests/models/conftest.py b/tests/models/conftest.py index acd85a30..9c3f9d3a 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -37,277 +37,207 @@ from openllm._llm import normalise_model_name logger = logging.getLogger(__name__) if t.TYPE_CHECKING: - import subprocess + import subprocess - from openllm_client.runtimes.base import BaseAsyncClient - from syrupy.assertion import SnapshotAssertion - from syrupy.types import PropertyFilter - from syrupy.types import PropertyMatcher - from syrupy.types import SerializableData - from syrupy.types import SerializedData + from openllm_client.runtimes.base import BaseAsyncClient + from syrupy.assertion import SnapshotAssertion + from syrupy.types import PropertyFilter + from syrupy.types import PropertyMatcher + from syrupy.types import SerializableData + from syrupy.types import SerializedData - from openllm._configuration import GenerationConfig - from openllm._types import DictStrAny - from openllm._types import ListAny + from openllm._configuration import GenerationConfig + from openllm._types import DictStrAny + from openllm._types import ListAny else: - DictStrAny = dict - ListAny = list - + DictStrAny = dict + ListAny = list class ResponseComparator(JSONSnapshotExtension): - def serialize( - self, - data: SerializableData, - *, - exclude: PropertyFilter | None = None, - matcher: PropertyMatcher | None = None, - ) -> SerializedData: - if openllm.utils.LazyType(ListAny).isinstance(data): - data = [d.unmarshaled for d in data] - else: - data = data.unmarshaled - data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher) - return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode() + def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData: + if openllm.utils.LazyType(ListAny).isinstance(data): + data = [d.unmarshaled for d in data] + else: + data = data.unmarshaled + data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher) + return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode() - def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool: - def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]: - try: - data = orjson.loads(data) - except orjson.JSONDecodeError as err: - raise ValueError(f"Failed to decode JSON data: {data}") from err - if openllm.utils.LazyType(DictStrAny).isinstance(data): - return openllm.GenerationOutput(**data) - elif openllm.utils.LazyType(ListAny).isinstance(data): - return [openllm.GenerationOutput(**d) for d in data] - else: - raise NotImplementedError(f"Data {data} has unsupported type.") + def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool: + def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]: + try: + data = orjson.loads(data) + except orjson.JSONDecodeError as err: + raise ValueError(f"Failed to decode JSON data: {data}") from err + if openllm.utils.LazyType(DictStrAny).isinstance(data): + return openllm.GenerationOutput(**data) + elif openllm.utils.LazyType(ListAny).isinstance(data): + return [openllm.GenerationOutput(**d) for d in data] + else: + raise NotImplementedError(f"Data {data} has unsupported type.") - serialized_data = convert_data(serialized_data) - snapshot_data = convert_data(snapshot_data) + serialized_data = convert_data(serialized_data) + snapshot_data = convert_data(snapshot_data) - if openllm.utils.LazyType(ListAny).isinstance(serialized_data): - serialized_data = [serialized_data] - if openllm.utils.LazyType(ListAny).isinstance(snapshot_data): - snapshot_data = [snapshot_data] + if openllm.utils.LazyType(ListAny).isinstance(serialized_data): + serialized_data = [serialized_data] + if openllm.utils.LazyType(ListAny).isinstance(snapshot_data): + snapshot_data = [snapshot_data] - def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool: - return s == t + def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool: + return s == t - def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool: - return ( - len(s.responses) == len(t.responses) - and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) - and eq_config(s.marshaled_config, t.marshaled_config) - ) - - return len(serialized_data) == len(snapshot_data) and all( - [eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)] - ) + def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool: + return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config)) + return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)]) @pytest.fixture() def response_snapshot(snapshot: SnapshotAssertion): - return snapshot.use_extension(ResponseComparator) - + return snapshot.use_extension(ResponseComparator) @attr.define(init=False) class _Handle(ABC): - port: int - deployment_mode: t.Literal["container", "local"] + port: int + deployment_mode: t.Literal["container", "local"] - client: BaseAsyncClient[t.Any] = attr.field(init=False) + client: BaseAsyncClient[t.Any] = attr.field(init=False) - if t.TYPE_CHECKING: + if t.TYPE_CHECKING: - def __attrs_init__(self, *args: t.Any, **attrs: t.Any): - ... + def __attrs_init__(self, *args: t.Any, **attrs: t.Any): + ... - def __attrs_post_init__(self): - self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}") + def __attrs_post_init__(self): + self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}") - @abstractmethod - def status(self) -> bool: - raise NotImplementedError - - async def health(self, timeout: int = 240): - start_time = time.time() - while time.time() - start_time < timeout: - if not self.status(): - raise RuntimeError(f"Failed to initialise {self.__class__.__name__}") - await self.client.health() - try: - await self.client.query("sanity") - return - except Exception: - time.sleep(1) - raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.") + @abstractmethod + def status(self) -> bool: + raise NotImplementedError + async def health(self, timeout: int = 240): + start_time = time.time() + while time.time() - start_time < timeout: + if not self.status(): + raise RuntimeError(f"Failed to initialise {self.__class__.__name__}") + await self.client.health() + try: + await self.client.query("sanity") + return + except Exception: + time.sleep(1) + raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.") @attr.define(init=False) class LocalHandle(_Handle): - process: subprocess.Popen[bytes] + process: subprocess.Popen[bytes] - def __init__( - self, - process: subprocess.Popen[bytes], - port: int, - deployment_mode: t.Literal["container", "local"], - ): - self.__attrs_init__(port, deployment_mode, process) - - def status(self) -> bool: - return self.process.poll() is None + def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal["container", "local"],): + self.__attrs_init__(port, deployment_mode, process) + def status(self) -> bool: + return self.process.poll() is None class HandleProtocol(t.Protocol): - @contextlib.contextmanager - def __call__( - *, - model: str, - model_id: str, - image_tag: str, - quantize: t.AnyStr | None = None, - ) -> t.Generator[_Handle, None, None]: - ... - + @contextlib.contextmanager + def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]: + ... @attr.define(init=False) class DockerHandle(_Handle): - container_name: str - docker_client: docker.DockerClient + container_name: str + docker_client: docker.DockerClient - def __init__( - self, - docker_client: docker.DockerClient, - container_name: str, - port: int, - deployment_mode: t.Literal["container", "local"], - ): - self.__attrs_init__(port, deployment_mode, container_name, docker_client) - - def status(self) -> bool: - container = self.docker_client.containers.get(self.container_name) - return container.status in ["running", "created"] + def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal["container", "local"],): + self.__attrs_init__(port, deployment_mode, container_name, docker_client) + def status(self) -> bool: + container = self.docker_client.containers.get(self.container_name) + return container.status in ["running", "created"] @contextlib.contextmanager -def _local_handle( - model: str, - model_id: str, - image_tag: str, - deployment_mode: t.Literal["container", "local"], - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - *, - _serve_grpc: bool = False, -): - with openllm.utils.reserve_free_port() as port: - pass +def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,): + with openllm.utils.reserve_free_port() as port: + pass - if not _serve_grpc: - proc = openllm.start( - model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True - ) - else: - proc = openllm.start_grpc( - model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True - ) + if not _serve_grpc: + proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True) + else: + proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True) - yield LocalHandle(proc, port, deployment_mode) - proc.terminate() - proc.wait(60) + yield LocalHandle(proc, port, deployment_mode) + proc.terminate() + proc.wait(60) - process_output = proc.stdout.read() - print(process_output, file=sys.stderr) - - proc.stdout.close() - if proc.stderr: - proc.stderr.close() + process_output = proc.stdout.read() + print(process_output, file=sys.stderr) + proc.stdout.close() + if proc.stderr: + proc.stderr.close() @contextlib.contextmanager -def _container_handle( - model: str, - model_id: str, - image_tag: str, - deployment_mode: t.Literal["container", "local"], - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - *, - _serve_grpc: bool = False, -): - envvar = openllm.utils.EnvVarMixin(model) - - with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port: - pass - container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_") - client = docker.from_env() - try: - container = client.containers.get(container_name) - container.stop() - container.wait() - container.remove() - except docker.errors.NotFound: - pass - - args = ["serve" if not _serve_grpc else "serve-grpc"] - - env: DictStrAny = {} - - if quantize is not None: - env[envvar.quantize] = quantize - - gpus = openllm.utils.device_count() or -1 - devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None - - container = client.containers.run( - image_tag, - command=args, - name=container_name, - environment=env, - auto_remove=False, - detach=True, - device_requests=devs, - ports={"3000/tcp": port, "3001/tcp": prom_port}, - ) - - yield DockerHandle(client, container.name, port, deployment_mode) - - try: - container.stop() - container.wait() - except docker.errors.NotFound: - pass - - container_output = container.logs().decode("utf-8") - print(container_output, file=sys.stderr) +def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,): + envvar = openllm.utils.EnvVarMixin(model) + with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port: + pass + container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_") + client = docker.from_env() + try: + container = client.containers.get(container_name) + container.stop() + container.wait() container.remove() + except docker.errors.NotFound: + pass + args = ["serve" if not _serve_grpc else "serve-grpc"] + + env: DictStrAny = {} + + if quantize is not None: + env[envvar.quantize] = quantize + + gpus = openllm.utils.device_count() or -1 + devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None + + container = client.containers.run(image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={"3000/tcp": port, "3001/tcp": prom_port},) + + yield DockerHandle(client, container.name, port, deployment_mode) + + try: + container.stop() + container.wait() + except docker.errors.NotFound: + pass + + container_output = container.logs().decode("utf-8") + print(container_output, file=sys.stderr) + + container.remove() @pytest.fixture(scope="session", autouse=True) def clean_context() -> t.Generator[contextlib.ExitStack, None, None]: - stack = contextlib.ExitStack() - yield stack - stack.close() - + stack = contextlib.ExitStack() + yield stack + stack.close() @pytest.fixture(scope="module") def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]: - loop = asyncio.get_event_loop() - yield loop - loop.close() - + loop = asyncio.get_event_loop() + yield loop + loop.close() @pytest.fixture(params=["container", "local"], scope="session") def deployment_mode(request: pytest.FixtureRequest) -> str: - return request.param - + return request.param @pytest.fixture(scope="module") def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]): - if deployment_mode == "container": - return functools.partial(_container_handle, deployment_mode=deployment_mode) - elif deployment_mode == "local": - return functools.partial(_local_handle, deployment_mode=deployment_mode) - else: - raise ValueError(f"Unknown deployment mode: {deployment_mode}") + if deployment_mode == "container": + return functools.partial(_container_handle, deployment_mode=deployment_mode) + elif deployment_mode == "local": + return functools.partial(_local_handle, deployment_mode=deployment_mode) + else: + raise ValueError(f"Unknown deployment mode: {deployment_mode}") diff --git a/tests/models/flan_t5_test.py b/tests/models/flan_t5_test.py index 2e149ad7..37ed2815 100644 --- a/tests/models/flan_t5_test.py +++ b/tests/models/flan_t5_test.py @@ -20,40 +20,30 @@ import pytest import openllm if t.TYPE_CHECKING: - import contextlib - - from .conftest import HandleProtocol - from .conftest import ResponseComparator - from .conftest import _Handle + import contextlib + from .conftest import HandleProtocol + from .conftest import ResponseComparator + from .conftest import _Handle model = "flan_t5" model_id = "google/flan-t5-small" - @pytest.fixture(scope="module") -def flan_t5_handle( - handler: HandleProtocol, - deployment_mode: t.Literal["container", "local"], - clean_context: contextlib.ExitStack, -): - with openllm.testing.prepare( - model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context - ) as image_tag: - with handler(model=model, model_id=model_id, image_tag=image_tag) as handle: - yield handle - +def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,): + with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag: + with handler(model=model, model_id=model_id, image_tag=image_tag) as handle: + yield handle @pytest.fixture(scope="module") async def flan_t5(flan_t5_handle: _Handle): - await flan_t5_handle.health(240) - return flan_t5_handle.client - + await flan_t5_handle.health(240) + return flan_t5_handle.client @pytest.mark.asyncio() async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator): - client = await flan_t5 - response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_response="attrs") + client = await flan_t5 + response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_response="attrs") - assert response.configuration["generation_config"]["max_new_tokens"] == 10 - assert response == response_snapshot + assert response.configuration["generation_config"]["max_new_tokens"] == 10 + assert response == response_snapshot diff --git a/tests/models/opt_test.py b/tests/models/opt_test.py index bfbd66ac..98f99fcd 100644 --- a/tests/models/opt_test.py +++ b/tests/models/opt_test.py @@ -19,40 +19,30 @@ import pytest import openllm if t.TYPE_CHECKING: - import contextlib - - from .conftest import HandleProtocol - from .conftest import ResponseComparator - from .conftest import _Handle + import contextlib + from .conftest import HandleProtocol + from .conftest import ResponseComparator + from .conftest import _Handle model = "opt" model_id = "facebook/opt-125m" - @pytest.fixture(scope="module") -def opt_125m_handle( - handler: HandleProtocol, - deployment_mode: t.Literal["container", "local"], - clean_context: contextlib.ExitStack, -): - with openllm.testing.prepare( - model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context - ) as image_tag: - with handler(model=model, model_id=model_id, image_tag=image_tag) as handle: - yield handle - +def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,): + with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag: + with handler(model=model, model_id=model_id, image_tag=image_tag) as handle: + yield handle @pytest.fixture(scope="module") async def opt_125m(opt_125m_handle: _Handle): - await opt_125m_handle.health(240) - return opt_125m_handle.client - + await opt_125m_handle.health(240) + return opt_125m_handle.client @pytest.mark.asyncio() async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator): - client = await opt_125m - response = await client.query("What is Deep learning?", max_new_tokens=20, return_response="attrs") + client = await opt_125m + response = await client.query("What is Deep learning?", max_new_tokens=20, return_response="attrs") - assert response.configuration["generation_config"]["max_new_tokens"] == 20 - assert response == response_snapshot + assert response.configuration["generation_config"]["max_new_tokens"] == 20 + assert response == response_snapshot diff --git a/tests/models_test.py b/tests/models_test.py index 850f2589..1fbf2362 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -19,25 +19,22 @@ import typing as t import pytest if t.TYPE_CHECKING: - import openllm - + import openllm @pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI") def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm(prompt) - - assert llm(prompt, temperature=0.8, top_p=0.23) + assert llm(prompt) + assert llm(prompt, temperature=0.8, top_p=0.23) @pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI") def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm(prompt) - - assert llm(prompt, temperature=0.9, top_k=8) + assert llm(prompt) + assert llm(prompt, temperature=0.9, top_k=8) @pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI") def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm(prompt) + assert llm(prompt) - assert llm(prompt, temperature=0.95) + assert llm(prompt, temperature=0.95) diff --git a/tests/package_test.py b/tests/package_test.py index c8cbe4c7..00c67a48 100644 --- a/tests/package_test.py +++ b/tests/package_test.py @@ -23,55 +23,44 @@ import openllm from bentoml._internal.configuration.containers import BentoMLContainer if t.TYPE_CHECKING: - from pathlib import Path - + from pathlib import Path HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5" -actions_xfail = functools.partial( - pytest.mark.xfail, - condition=os.getenv("GITHUB_ACTIONS") is not None, - reason="Marking GitHub Actions to xfail due to flakiness and building environment not isolated.", -) - +actions_xfail = functools.partial(pytest.mark.xfail, condition=os.getenv("GITHUB_ACTIONS") is not None, reason="Marking GitHub Actions to xfail due to flakiness and building environment not isolated.",) @actions_xfail def test_general_build_with_internal_testing(): - bento_store = BentoMLContainer.bento_store.get() + bento_store = BentoMLContainer.bento_store.get() - llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING) - bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING) + llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING) + bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING) - assert llm.llm_type == bento.info.labels["_type"] - assert llm.config["env"]["framework_value"] == bento.info.labels["_framework"] - - bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING) - assert len(bento_store.list(bento.tag)) == 1 + assert llm.llm_type == bento.info.labels["_type"] + assert llm.config["env"]["framework_value"] == bento.info.labels["_framework"] + bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING) + assert len(bento_store.list(bento.tag)) == 1 @actions_xfail def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory): - local_path = tmp_path_factory.mktemp("local_t5") - llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True) + local_path = tmp_path_factory.mktemp("local_t5") + llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True) - if llm.bettertransformer: - llm.__llm_model__ = llm.model.reverse_bettertransformer() + if llm.bettertransformer: + llm.__llm_model__ = llm.model.reverse_bettertransformer() - llm.save_pretrained(local_path) - - assert openllm.build("flan-t5", model_id=local_path.resolve().__fspath__(), model_version="local") + llm.save_pretrained(local_path) + assert openllm.build("flan-t5", model_id=local_path.resolve().__fspath__(), model_version="local") @pytest.fixture() def dockerfile_template(tmp_path_factory: pytest.TempPathFactory): - file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template" - file.write_text( - "{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}" - ) - return file - + file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template" + file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}") + return file @pytest.mark.usefixtures("dockerfile_template") @actions_xfail def test_build_with_custom_dockerfile(dockerfile_template: Path): - assert openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template)) + assert openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template)) diff --git a/tests/strategies_test.py b/tests/strategies_test.py index bde7454e..d1429e57 100644 --- a/tests/strategies_test.py +++ b/tests/strategies_test.py @@ -19,7 +19,7 @@ import typing as t import pytest if t.TYPE_CHECKING: - from _pytest.monkeypatch import MonkeyPatch + from _pytest.monkeypatch import MonkeyPatch import bentoml from bentoml._internal.resource import get_resource @@ -28,186 +28,162 @@ from openllm._strategies import CascadingResourceStrategy from openllm._strategies import NvidiaGpuResource def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "0,1") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 2 - assert resource == ["0", "1"] - mcls.delenv("CUDA_VISIBLE_DEVICES") - + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "0,1") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 2 + assert resource == ["0", "1"] + mcls.delenv("CUDA_VISIBLE_DEVICES") def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "0,2,-1,1") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 2 - assert resource == ["0", "2"] - mcls.delenv("CUDA_VISIBLE_DEVICES") - + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "0,2,-1,1") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 2 + assert resource == ["0", "2"] + mcls.delenv("CUDA_VISIBLE_DEVICES") def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "-1") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 0 - assert resource == [] - mcls.delenv("CUDA_VISIBLE_DEVICES") - + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "-1") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 0 + assert resource == [] + mcls.delenv("CUDA_VISIBLE_DEVICES") def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43-ac33420d4628") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 1 - assert resource == ["GPU-5ebe9f43-ac33420d4628"] - mcls.delenv("CUDA_VISIBLE_DEVICES") - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,GPU-ac33420d4628") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 2 - assert resource == ["GPU-5ebe9f43", "GPU-ac33420d4628"] - mcls.delenv("CUDA_VISIBLE_DEVICES") - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,-1,GPU-ac33420d4628") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 1 - assert resource == ["GPU-5ebe9f43"] - mcls.delenv("CUDA_VISIBLE_DEVICES") - with monkeypatch.context() as mcls: - mcls.setenv("CUDA_VISIBLE_DEVICES", "MIG-GPU-5ebe9f43-ac33420d4628") - resource = NvidiaGpuResource.from_system() - assert len(resource) == 1 - assert resource == ["MIG-GPU-5ebe9f43-ac33420d4628"] - mcls.delenv("CUDA_VISIBLE_DEVICES") - + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43-ac33420d4628") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 1 + assert resource == ["GPU-5ebe9f43-ac33420d4628"] + mcls.delenv("CUDA_VISIBLE_DEVICES") + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,GPU-ac33420d4628") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 2 + assert resource == ["GPU-5ebe9f43", "GPU-ac33420d4628"] + mcls.delenv("CUDA_VISIBLE_DEVICES") + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,-1,GPU-ac33420d4628") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 1 + assert resource == ["GPU-5ebe9f43"] + mcls.delenv("CUDA_VISIBLE_DEVICES") + with monkeypatch.context() as mcls: + mcls.setenv("CUDA_VISIBLE_DEVICES", "MIG-GPU-5ebe9f43-ac33420d4628") + resource = NvidiaGpuResource.from_system() + assert len(resource) == 1 + assert resource == ["MIG-GPU-5ebe9f43-ac33420d4628"] + mcls.delenv("CUDA_VISIBLE_DEVICES") @pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="skip GPUs test on CI") def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - # to make this tests works with system that has GPU - mcls.setenv("CUDA_VISIBLE_DEVICES", "") - assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests - - assert pytest.raises( - ValueError, - NvidiaGpuResource.validate, - [*NvidiaGpuResource.from_system(), 1], - ).match("Input list should be all string type.") - assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match( - "Input list should be all string type." - ) - assert pytest.raises(ValueError, NvidiaGpuResource.validate, ["GPU-5ebe9f43", "GPU-ac33420d4628"]).match( - "Failed to parse available GPUs UUID" - ) + with monkeypatch.context() as mcls: + # to make this tests works with system that has GPU + mcls.setenv("CUDA_VISIBLE_DEVICES", "") + assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests + assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match("Input list should be all string type.") + assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match("Input list should be all string type.") + assert pytest.raises(ValueError, NvidiaGpuResource.validate, ["GPU-5ebe9f43", "GPU-ac33420d4628"]).match("Failed to parse available GPUs UUID") def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - # to make this tests works with system that has GPU - mcls.setenv("CUDA_VISIBLE_DEVICES", "") - assert NvidiaGpuResource.from_spec(1) == ["0"] - assert NvidiaGpuResource.from_spec("5") == ["0", "1", "2", "3", "4"] - assert NvidiaGpuResource.from_spec(1) == ["0"] - assert NvidiaGpuResource.from_spec(2) == ["0", "1"] - assert NvidiaGpuResource.from_spec("3") == ["0", "1", "2"] - assert NvidiaGpuResource.from_spec([1, 3]) == ["1", "3"] - assert NvidiaGpuResource.from_spec(["1", "3"]) == ["1", "3"] - assert NvidiaGpuResource.from_spec(-1) == [] - assert NvidiaGpuResource.from_spec("-1") == [] - assert NvidiaGpuResource.from_spec("") == [] - assert NvidiaGpuResource.from_spec("-2") == [] - assert NvidiaGpuResource.from_spec("GPU-288347ab") == ["GPU-288347ab"] - assert NvidiaGpuResource.from_spec("GPU-288347ab,-1,GPU-ac33420d4628") == ["GPU-288347ab"] - assert NvidiaGpuResource.from_spec("GPU-288347ab,GPU-ac33420d4628") == ["GPU-288347ab", "GPU-ac33420d4628"] - assert NvidiaGpuResource.from_spec("MIG-GPU-288347ab") == ["MIG-GPU-288347ab"] - - with pytest.raises(TypeError): - NvidiaGpuResource.from_spec((1, 2, 3)) - with pytest.raises(TypeError): - NvidiaGpuResource.from_spec(1.5) - with pytest.raises(ValueError): - assert NvidiaGpuResource.from_spec(-2) + with monkeypatch.context() as mcls: + # to make this tests works with system that has GPU + mcls.setenv("CUDA_VISIBLE_DEVICES", "") + assert NvidiaGpuResource.from_spec(1) == ["0"] + assert NvidiaGpuResource.from_spec("5") == ["0", "1", "2", "3", "4"] + assert NvidiaGpuResource.from_spec(1) == ["0"] + assert NvidiaGpuResource.from_spec(2) == ["0", "1"] + assert NvidiaGpuResource.from_spec("3") == ["0", "1", "2"] + assert NvidiaGpuResource.from_spec([1, 3]) == ["1", "3"] + assert NvidiaGpuResource.from_spec(["1", "3"]) == ["1", "3"] + assert NvidiaGpuResource.from_spec(-1) == [] + assert NvidiaGpuResource.from_spec("-1") == [] + assert NvidiaGpuResource.from_spec("") == [] + assert NvidiaGpuResource.from_spec("-2") == [] + assert NvidiaGpuResource.from_spec("GPU-288347ab") == ["GPU-288347ab"] + assert NvidiaGpuResource.from_spec("GPU-288347ab,-1,GPU-ac33420d4628") == ["GPU-288347ab"] + assert NvidiaGpuResource.from_spec("GPU-288347ab,GPU-ac33420d4628") == ["GPU-288347ab", "GPU-ac33420d4628"] + assert NvidiaGpuResource.from_spec("MIG-GPU-288347ab") == ["MIG-GPU-288347ab"] + with pytest.raises(TypeError): + NvidiaGpuResource.from_spec((1, 2, 3)) + with pytest.raises(TypeError): + NvidiaGpuResource.from_spec(1.5) + with pytest.raises(ValueError): + assert NvidiaGpuResource.from_spec(-2) class GPURunnable(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu") - + SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu") def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False): - return get_resource(x, y, validate=validate) - + return get_resource(x, y, validate=validate) @pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"]) def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str): - monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource) - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 2 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 2) == 4 - assert pytest.raises( - ValueError, - CascadingResourceStrategy.get_worker_count, - GPURunnable, - {gpu_type: 0}, - 1, - ).match("No known supported resource available for *") - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 2 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 2) == 4 - - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 0.5) == 1 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 2 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 2 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 2 + monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource) + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 2 + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 2) == 4 + assert pytest.raises(ValueError, CascadingResourceStrategy.get_worker_count, GPURunnable, {gpu_type: 0}, 1,).match("No known supported resource available for *") + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 2 + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 2) == 4 + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 0.5) == 1 + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 2 + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 2 + assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 2 @pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"]) def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str): - monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource) + monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource) - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "0" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "1" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "7" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "0" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "1" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "7" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "0" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "0" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2) - assert envs.get("CUDA_VISIBLE_DEVICES") == "1" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "2" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2) - assert envs.get("CUDA_VISIBLE_DEVICES") == "7" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "0" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "0" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2) + assert envs.get("CUDA_VISIBLE_DEVICES") == "1" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "2" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2) + assert envs.get("CUDA_VISIBLE_DEVICES") == "7" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "8,9" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7,8,9" - - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "2,6" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8" - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2) - assert envs.get("CUDA_VISIBLE_DEVICES") == "9" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "8,9" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7,8,9" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "2,6" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8" + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2) + assert envs.get("CUDA_VISIBLE_DEVICES") == "9" @pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"]) def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str): - monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource) + monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource) - monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "") - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0) - assert envs.get("CUDA_VISIBLE_DEVICES") == "" - monkeypatch.delenv("CUDA_VISIBLE_DEVICES") + monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "") + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0) + assert envs.get("CUDA_VISIBLE_DEVICES") == "" + monkeypatch.delenv("CUDA_VISIBLE_DEVICES") - monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "-1") - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1) - assert envs.get("CUDA_VISIBLE_DEVICES") == "-1" - monkeypatch.delenv("CUDA_VISIBLE_DEVICES") + monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "-1") + envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1) + assert envs.get("CUDA_VISIBLE_DEVICES") == "-1" + monkeypatch.delenv("CUDA_VISIBLE_DEVICES") diff --git a/tools/assert-model-table-latest b/tools/assert-model-table-latest index af19d26e..5eb99b7c 100755 --- a/tools/assert-model-table-latest +++ b/tools/assert-model-table-latest @@ -8,32 +8,23 @@ import sys from markdown_it import MarkdownIt - md = MarkdownIt() ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) with open(os.path.join(ROOT, "README.md"), "r") as f: - readme = md.parse(f.read()) + readme = md.parse(f.read()) # NOTE: Currently, we only have one table in README, which is the Model readme. table = [r for r in readme if r.type == "html_block" and r.content.startswith(" dict[int, str]: - return { - v: status - for v, status in zip( - range(1, 8), - [ - "1 - Planning", - "2 - Pre-Alpha", - "3 - Alpha", - "4 - Beta", - "5 - Production/Stable", - "6 - Mature", - "7 - Inactive", - ], - ) - } + @staticmethod + def status() -> dict[int, str]: + return {v: status for v, status in zip(range(1, 8), ["1 - Planning", "2 - Pre-Alpha", "3 - Alpha", "4 - Beta", "5 - Production/Stable", "6 - Mature", "7 - Inactive",],)} - @staticmethod - def apache() -> str: - return Classifier.create_classifier("license", "OSI Approved", "Apache Software License") + @staticmethod + def apache() -> str: + return Classifier.create_classifier("license", "OSI Approved", "Apache Software License") - @staticmethod - def create_classifier(identifier: str, *decls: t.Any) -> str: - cls_ = Classifier() - if identifier not in cls_.identifier: - raise ValueError(f"{identifier} is not yet supported (supported alias: {Classifier.identifier})") - return cls_.joiner.join([cls_.identifier[identifier], *decls]) + @staticmethod + def create_classifier(identifier: str, *decls: t.Any) -> str: + cls_ = Classifier() + if identifier not in cls_.identifier: + raise ValueError(f"{identifier} is not yet supported (supported alias: {Classifier.identifier})") + return cls_.joiner.join([cls_.identifier[identifier], *decls]) - @staticmethod - def create_python_classifier( - implementation: list[str] | None = None, supported_version: list[str] | None = None - ) -> list[str]: - if supported_version is None: - supported_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] - if implementation is None: - implementation = ["CPython", "PyPy"] - base = [ - Classifier.create_classifier("language", "Python"), - Classifier.create_classifier("language", "Python", "3"), - ] - base.append(Classifier.create_classifier("language", "Python", "3", "Only")) - base.extend([Classifier.create_classifier("language", "Python", version) for version in supported_version]) - base.extend( - [Classifier.create_classifier("language", "Python", "Implementation", impl) for impl in implementation] - ) - return base - - @staticmethod - def create_status_classifier(level: int) -> str: - return Classifier.create_classifier("status", Classifier.status()[level]) + @staticmethod + def create_python_classifier(implementation: list[str] | None = None, supported_version: list[str] | None = None) -> list[str]: + if supported_version is None: + supported_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] + if implementation is None: + implementation = ["CPython", "PyPy"] + base = [Classifier.create_classifier("language", "Python"), Classifier.create_classifier("language", "Python", "3"),] + base.append(Classifier.create_classifier("language", "Python", "3", "Only")) + base.extend([Classifier.create_classifier("language", "Python", version) for version in supported_version]) + base.extend([Classifier.create_classifier("language", "Python", "Implementation", impl) for impl in implementation]) + return base + @staticmethod + def create_status_classifier(level: int) -> str: + return Classifier.create_classifier("status", Classifier.status()[level]) @dataclasses.dataclass(frozen=True) class Dependencies: - name: str - git_repo_url: t.Optional[str] = None - branch: t.Optional[str] = None - extensions: t.Optional[t.List[str]] = None - subdirectory: t.Optional[str] = None - requires_gpu: bool = False - lower_constraint: t.Optional[str] = None - upper_constraint: t.Optional[str] = None - platform: t.Optional[t.Tuple[t.Literal["Linux", "Windows", "Darwin"], t.Literal["eq", "ne"]]] = None + name: str + git_repo_url: t.Optional[str] = None + branch: t.Optional[str] = None + extensions: t.Optional[t.List[str]] = None + subdirectory: t.Optional[str] = None + requires_gpu: bool = False + lower_constraint: t.Optional[str] = None + upper_constraint: t.Optional[str] = None + platform: t.Optional[t.Tuple[t.Literal["Linux", "Windows", "Darwin"], t.Literal["eq", "ne"]]] = None - def with_options(self, **kwargs: t.Any) -> Dependencies: - return dataclasses.replace(self, **kwargs) + def with_options(self, **kwargs: t.Any) -> Dependencies: + return dataclasses.replace(self, **kwargs) - @property - def has_constraint(self) -> bool: - return self.lower_constraint is not None or self.upper_constraint is not None + @property + def has_constraint(self) -> bool: + return self.lower_constraint is not None or self.upper_constraint is not None - @property - def pypi_extensions(self) -> str: - return "" if self.extensions is None else f"[{','.join(self.extensions)}]" + @property + def pypi_extensions(self) -> str: + return "" if self.extensions is None else f"[{','.join(self.extensions)}]" - @staticmethod - def platform_restriction(platform: t.LiteralString, op: t.Literal["eq", "ne"] = "eq") -> str: - return f'platform_system{"==" if op == "eq" else "!="}"{platform}"' + @staticmethod + def platform_restriction(platform: t.LiteralString, op: t.Literal["eq", "ne"] = "eq") -> str: + return f'platform_system{"==" if op == "eq" else "!="}"{platform}"' - def to_str(self) -> str: - deps: list[str] = [] - if self.lower_constraint is not None and self.upper_constraint is not None: - dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint},<{self.upper_constraint}" - elif self.lower_constraint is not None: - dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint}" - elif self.upper_constraint is not None: - dep = f"{self.name}{self.pypi_extensions}<{self.upper_constraint}" - elif self.subdirectory is not None: - dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git#subdirectory={self.subdirectory}" - elif self.branch is not None: - dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git@{self.branch}" - else: - dep = f"{self.name}{self.pypi_extensions}" + def to_str(self) -> str: + deps: list[str] = [] + if self.lower_constraint is not None and self.upper_constraint is not None: + dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint},<{self.upper_constraint}" + elif self.lower_constraint is not None: + dep = f"{self.name}{self.pypi_extensions}>={self.lower_constraint}" + elif self.upper_constraint is not None: + dep = f"{self.name}{self.pypi_extensions}<{self.upper_constraint}" + elif self.subdirectory is not None: + dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git#subdirectory={self.subdirectory}" + elif self.branch is not None: + dep = f"{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git@{self.branch}" + else: + dep = f"{self.name}{self.pypi_extensions}" - deps.append(dep) + deps.append(dep) - if self.platform: - deps.append(self.platform_restriction(*self.platform)) + if self.platform: + deps.append(self.platform_restriction(*self.platform)) - return ";".join(deps) - - @classmethod - def from_tuple(cls, *decls: t.Any) -> Dependencies: - return cls(*decls) + return ";".join(deps) + @classmethod + def from_tuple(cls, *decls: t.Any) -> Dependencies: + return cls(*decls) _BENTOML_EXT = ["grpc", "io"] _TRANSFORMERS_EXT = ["torch", "tokenizers", "accelerate"] @@ -179,14 +142,8 @@ _BASE_DEPENDENCIES = [ ] _NIGHTLY_MAPPING: dict[str, Dependencies] = { - "bentoml": Dependencies.from_tuple("bentoml", "bentoml/bentoml", "main", _BENTOML_EXT), - "peft": Dependencies.from_tuple("peft", "huggingface/peft", "main", None), - "transformers": Dependencies.from_tuple("transformers", "huggingface/transformers", "main", _TRANSFORMERS_EXT), - "optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None), - "accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None), - "bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None), - "trl": Dependencies.from_tuple("trl", "lvwerra/trl", "main", None), - "vllm": Dependencies.from_tuple("vllm", "vllm-project/vllm", "main", None, None, True, None), + "bentoml": Dependencies.from_tuple("bentoml", "bentoml/bentoml", "main", _BENTOML_EXT), "peft": Dependencies.from_tuple("peft", "huggingface/peft", "main", None), "transformers": Dependencies.from_tuple("transformers", "huggingface/transformers", "main", _TRANSFORMERS_EXT), "optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None), + "accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None), "bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None), "trl": Dependencies.from_tuple("trl", "lvwerra/trl", "main", None), "vllm": Dependencies.from_tuple("vllm", "vllm-project/vllm", "main", None, None, True, None), } _ALL_RUNTIME_DEPS = ["flax", "jax", "jaxlib", "tensorflow", "keras"] @@ -200,114 +157,91 @@ GGML_DEPS = ["ctransformers"] GPTQ_DEPS = ["auto-gptq[triton]"] VLLM_DEPS = ["vllm", "ray"] -_base_requirements: dict[str, t.Any]= { - inflection.dasherize(name): config_cls.__openllm_requirements__ - for name, config_cls in openllm.CONFIG_MAPPING.items() - if config_cls.__openllm_requirements__ -} +_base_requirements: dict[str, t.Any] = {inflection.dasherize(name): config_cls.__openllm_requirements__ for name, config_cls in openllm.CONFIG_MAPPING.items() if config_cls.__openllm_requirements__} # shallow copy from locals() _locals = locals().copy() # NOTE: update this table when adding new external dependencies # sync with openllm.utils.OPTIONAL_DEPENDENCIES -_base_requirements.update( - {v: _locals.get(f"{inflection.underscore(v).upper()}_DEPS") for v in openllm.utils.OPTIONAL_DEPENDENCIES} -) +_base_requirements.update({v: _locals.get(f"{inflection.underscore(v).upper()}_DEPS") for v in openllm.utils.OPTIONAL_DEPENDENCIES}) _base_requirements = {k: v for k, v in sorted(_base_requirements.items())} fname = f"{os.path.basename(os.path.dirname(__file__))}/{os.path.basename(__file__)}" - def create_classifiers() -> Array: - arr = tomlkit.array() - arr.extend( - [ - Classifier.create_status_classifier(5), - Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA"), - Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "12"), - Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.8"), - Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.7"), - Classifier.apache(), - Classifier.create_classifier("topic", "Scientific/Engineering", "Artificial Intelligence"), - Classifier.create_classifier("topic", "Software Development", "Libraries"), - Classifier.create_classifier("os", "OS Independent"), - Classifier.create_classifier("audience", "Developers"), - Classifier.create_classifier("audience", "Science/Research"), - Classifier.create_classifier("audience", "System Administrators"), - Classifier.create_classifier("typing", "Typed"), - *Classifier.create_python_classifier(), - ] - ) - return arr.multiline(True) - + arr = tomlkit.array() + arr.extend([ + Classifier.create_status_classifier(5), + Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA"), + Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "12"), + Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.8"), + Classifier.create_classifier("environment", "GPU", "NVIDIA CUDA", "11.7"), + Classifier.apache(), + Classifier.create_classifier("topic", "Scientific/Engineering", "Artificial Intelligence"), + Classifier.create_classifier("topic", "Software Development", "Libraries"), + Classifier.create_classifier("os", "OS Independent"), + Classifier.create_classifier("audience", "Developers"), + Classifier.create_classifier("audience", "Science/Research"), + Classifier.create_classifier("audience", "System Administrators"), + Classifier.create_classifier("typing", "Typed"), *Classifier.create_python_classifier(), + ]) + return arr.multiline(True) def create_optional_table() -> Table: - all_array = tomlkit.array() - all_array.extend([f"openllm[{k}]" for k in _base_requirements]) + all_array = tomlkit.array() + all_array.extend([f"openllm[{k}]" for k in _base_requirements]) - table = tomlkit.table(is_super_table=True) - _base_requirements.update({"all": all_array.multiline(True)}) - table.update({k: v for k, v in sorted(_base_requirements.items())}) - table.add(tomlkit.nl()) - - return table + table = tomlkit.table(is_super_table=True) + _base_requirements.update({"all": all_array.multiline(True)}) + table.update({k: v for k, v in sorted(_base_requirements.items())}) + table.add(tomlkit.nl()) + return table def create_url_table() -> Table: - table = tomlkit.table() - _urls = { - "Blog": "https://modelserving.com", - "Chat": "https://discord.gg/openllm", - "Documentation": "https://github.com/bentoml/openllm#readme", - "GitHub": "https://github.com/bentoml/openllm", - "History": "https://github.com/bentoml/openllm/blob/main/CHANGELOG.md", - "Homepage": "https://bentoml.com", - "Tracker": "https://github.com/bentoml/openllm/issues", - "Twitter": "https://twitter.com/bentomlai", - } - table.update({k: v for k, v in sorted(_urls.items())}) - return table + table = tomlkit.table() + _urls = { + "Blog": "https://modelserving.com", "Chat": "https://discord.gg/openllm", "Documentation": "https://github.com/bentoml/openllm#readme", "GitHub": "https://github.com/bentoml/openllm", "History": "https://github.com/bentoml/openllm/blob/main/CHANGELOG.md", "Homepage": "https://bentoml.com", "Tracker": "https://github.com/bentoml/openllm/issues", + "Twitter": "https://twitter.com/bentomlai", + } + table.update({k: v for k, v in sorted(_urls.items())}) + return table def build_cli_extensions() -> Table: - table = tomlkit.table() - ext: dict[str, str] = {"openllm": "openllm.cli.entrypoint:cli"} - ext.update({f"openllm-{inflection.dasherize(ke)}": f"openllm.cli.ext.{ke}:cli" for ke in sorted([fname[:-3] - for fname in os.listdir(os.path.abspath(os.path.join(ROOT, "src", "openllm", "cli", "ext"))) - if fname.endswith(".py") and not fname.startswith("__")])}) - table.update(ext) - return table + table = tomlkit.table() + ext: dict[str, str] = {"openllm": "openllm.cli.entrypoint:cli"} + ext.update({f"openllm-{inflection.dasherize(ke)}": f"openllm.cli.ext.{ke}:cli" for ke in sorted([fname[:-3] for fname in os.listdir(os.path.abspath(os.path.join(ROOT, "src", "openllm", "cli", "ext"))) if fname.endswith(".py") and not fname.startswith("__")])}) + table.update(ext) + return table def main() -> int: - with open(os.path.join(ROOT, "pyproject.toml"), "r") as f: - pyproject = tomlkit.parse(f.read()) + with open(os.path.join(ROOT, "pyproject.toml"), "r") as f: + pyproject = tomlkit.parse(f.read()) - dependencies_array = tomlkit.array() - dependencies_array.extend([v.to_str() for v in _BASE_DEPENDENCIES]) + dependencies_array = tomlkit.array() + dependencies_array.extend([v.to_str() for v in _BASE_DEPENDENCIES]) - pyproject["project"]["urls"] = create_url_table() - pyproject["project"]["scripts"] = build_cli_extensions() - pyproject["project"]["classifiers"] = create_classifiers() - pyproject["project"]["optional-dependencies"] = create_optional_table() - pyproject["project"]["dependencies"] = dependencies_array.multiline(True) + pyproject["project"]["urls"] = create_url_table() + pyproject["project"]["scripts"] = build_cli_extensions() + pyproject["project"]["classifiers"] = create_classifiers() + pyproject["project"]["optional-dependencies"] = create_optional_table() + pyproject["project"]["dependencies"] = dependencies_array.multiline(True) - with open(os.path.join(ROOT, "pyproject.toml"), "w") as f: - f.write(tomlkit.dumps(pyproject)) + with open(os.path.join(ROOT, "pyproject.toml"), "w") as f: + f.write(tomlkit.dumps(pyproject)) - with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f: - f.write(f"# This file is generated by `{fname}`. DO NOT EDIT\n-e .[playground,flan-t5]\n") - f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if not v.requires_gpu]) - with open(os.path.join(ROOT, "nightly-requirements-gpu.txt"), "w") as f: - f.write(f"# This file is generated by `{fname}`. # DO NOT EDIT\n") - f.write( - "# For Jax, Flax, Tensorflow, PyTorch CUDA support, please refers to their official installation for your specific setup.\n" - ) - f.write("-r nightly-requirements.txt\n-e .[all]\n") - f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if v.requires_gpu]) - - return 0 + with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f: + f.write(f"# This file is generated by `{fname}`. DO NOT EDIT\n-e .[playground,flan-t5]\n") + f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if not v.requires_gpu]) + with open(os.path.join(ROOT, "nightly-requirements-gpu.txt"), "w") as f: + f.write(f"# This file is generated by `{fname}`. # DO NOT EDIT\n") + f.write("# For Jax, Flax, Tensorflow, PyTorch CUDA support, please refers to their official installation for your specific setup.\n") + f.write("-r nightly-requirements.txt\n-e .[all]\n") + f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if v.requires_gpu]) + return 0 if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) diff --git a/tools/generate-coverage.py b/tools/generate-coverage.py index 12b21cdd..7246a4ad 100755 --- a/tools/generate-coverage.py +++ b/tools/generate-coverage.py @@ -21,51 +21,48 @@ from pathlib import Path import orjson from lxml import etree - ROOT = Path(__file__).resolve().parent.parent PACKAGES = {"src/openllm/": "openllm"} - def main() -> int: - coverage_report = ROOT / "coverage.xml" - root = etree.fromstring(coverage_report.read_text()) + coverage_report = ROOT / "coverage.xml" + root = etree.fromstring(coverage_report.read_text()) - raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {"hits": 0, "misses": 0}) - for package in root.find("packages"): - for module in package.find("classes"): - filename = module.attrib["filename"] - for relative_path, package_name in PACKAGES.items(): - if filename.startswith(relative_path): - data = raw_package_data[package_name] - break - else: - message = f"unknown package: {module}" - raise ValueError(message) + raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {"hits": 0, "misses": 0}) + for package in root.find("packages"): + for module in package.find("classes"): + filename = module.attrib["filename"] + for relative_path, package_name in PACKAGES.items(): + if filename.startswith(relative_path): + data = raw_package_data[package_name] + break + else: + message = f"unknown package: {module}" + raise ValueError(message) - for line in module.find("lines"): - if line.attrib["hits"] == "1": - data["hits"] += 1 - else: - data["misses"] += 1 + for line in module.find("lines"): + if line.attrib["hits"] == "1": + data["hits"] += 1 + else: + data["misses"] += 1 - total_statements_covered = 0 - total_statements = 0 - coverage_data = {} - for package_name, data in sorted(raw_package_data.items()): - statements_covered = data["hits"] - statements = statements_covered + data["misses"] - total_statements_covered += statements_covered - total_statements += statements + total_statements_covered = 0 + total_statements = 0 + coverage_data = {} + for package_name, data in sorted(raw_package_data.items()): + statements_covered = data["hits"] + statements = statements_covered + data["misses"] + total_statements_covered += statements_covered + total_statements += statements - coverage_data[package_name] = {"statements_covered": statements_covered, "statements": statements} - coverage_data["total"] = {"statements_covered": total_statements_covered, "statements": total_statements} + coverage_data[package_name] = {"statements_covered": statements_covered, "statements": statements} + coverage_data["total"] = {"statements_covered": total_statements_covered, "statements": total_statements} - coverage_summary = ROOT / "coverage-summary.json" - coverage_summary.write_text(orjson.dumps(coverage_data, option=orjson.OPT_INDENT_2).decode(), encoding="utf-8") - - return 0 + coverage_summary = ROOT / "coverage-summary.json" + coverage_summary.write_text(orjson.dumps(coverage_data, option=orjson.OPT_INDENT_2).decode(), encoding="utf-8") + return 0 if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) diff --git a/tools/update-config-stubs.py b/tools/update-config-stubs.py index fde63612..a705578d 100755 --- a/tools/update-config-stubs.py +++ b/tools/update-config-stubs.py @@ -25,8 +25,7 @@ from openllm._configuration import GenerationConfig from openllm._configuration import ModelSettings from openllm._configuration import PeftType - -# currently we are assuming the indentatio level is 4 for comments +# currently we are assuming the indentatio level is 2 for comments START_COMMENT = f"# {os.path.basename(__file__)}: start\n" END_COMMENT = f"# {os.path.basename(__file__)}: stop\n" START_SPECIAL_COMMENT = f"# {os.path.basename(__file__)}: special start\n" @@ -38,28 +37,26 @@ _TARGET_FILE = Path(__file__).parent.parent / "src" / "openllm" / "_configuratio _imported = importlib.import_module(ModelSettings.__module__) def process_annotations(annotations: str) -> str: - if "NotRequired" in annotations: - return annotations[len("NotRequired[") : -1] - elif "Required" in annotations: - return annotations[len("Required[") : -1] - else: - return annotations + if "NotRequired" in annotations: + return annotations[len("NotRequired["):-1] + elif "Required" in annotations: + return annotations[len("Required["):-1] + else: + return annotations _value_docstring = { "default_id": """Return the default model to use when using 'openllm start '. This could be one of the keys in 'self.model_ids' or custom users model. This field is required when defining under '__config__'. - """, - "model_ids": """A list of supported pretrained models tag for this given runnable. + """, "model_ids": """A list of supported pretrained models tag for this given runnable. For example: For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"] This field is required when defining under '__config__'. - """, - "architecture": """The model architecture that is supported by this LLM. + """, "architecture": """The model architecture that is supported by this LLM. Note that any model weights within this architecture generation can always be run and supported by this LLM. @@ -68,29 +65,16 @@ _value_docstring = { ```bash openllm start gpt-neox --model-id stabilityai/stablelm-tuned-alpha-3b - ```""", - "default_implementation": """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. + ```""", "default_implementation": """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. - It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm') - """, - "url": """The resolved url for this LLMConfig.""", - "requires_gpu": """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""", - "trust_remote_code": """Whether to always trust remote code""", - "service_name": """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""", + It is a dictionary of key as the accelerator spec in k4s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm') + """, "url": """The resolved url for this LLMConfig.""", "requires_gpu": """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""", "trust_remote_code": """Whether to always trust remote code""", "service_name": """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""", "requirements": """The default PyPI requirements needed to run this given LLM. By default, we will depend on - bentoml, torch, transformers.""", - "bettertransformer": """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""", - "model_type": """The model type for this given LLM. By default, it should be causal language modeling. + bentoml, torch, transformers.""", "bettertransformer": """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""", "model_type": """The model type for this given LLM. By default, it should be causal language modeling. Currently supported 'causal_lm' or 'seq2seq_lm' - """, - "runtime": """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""", - "name_type": """The default name typed for this model. "dasherize" will convert the name to lowercase and + """, "runtime": """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""", "name_type": """The default name typed for this model. "dasherize" will convert the name to lowercase and replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both - `model_name` and `start_name` must be specified.""", - "model_name": """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""", - "start_name": """Default name to be used with `openllm start`""", - "env": """A EnvVarMixin instance for this LLMConfig.""", - "timeout": """The default timeout to be set for this given LLM.""", + `model_name` and `start_name` must be specified.""", "model_name": """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""", "start_name": """Default name to be used with `openllm start`""", "env": """A EnvVarMixin instance for this LLMConfig.""", "timeout": """The default timeout to be set for this given LLM.""", "workers_per_resource": """The number of workers per resource. This is used to determine the number of workers to use for this model. For example, if this is set to 0.5, then OpenLLM will use 1 worker per 2 resources. If this is set to 1, then OpenLLM will use 1 worker per resource. If this is set to 2, then OpenLLM will use 2 workers per resource. @@ -99,106 +83,58 @@ _value_docstring = { https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more details. By default, it is set to 1. - """, - "fine_tune_strategies": """The fine-tune strategies for this given LLM.""", - "tokenizer_class": """Optional tokenizer class for this given LLM. See Llama for example.""", + """, "fine_tune_strategies": """The fine-tune strategies for this given LLM.""", "tokenizer_class": """Optional tokenizer class for this given LLM. See Llama for example.""", } _transformed = {"fine_tune_strategies": "t.Dict[AdapterType, FineTuneConfig]"} - def main() -> int: - with _TARGET_FILE.open("r") as f: - processed = f.readlines() + with _TARGET_FILE.open("r") as f: + processed = f.readlines() - start_idx, end_idx = processed.index(" " * 4 + START_COMMENT), processed.index(" " * 4 + END_COMMENT) - start_stub_idx, end_stub_idx = processed.index(" " * 8 + START_SPECIAL_COMMENT), processed.index(" " * 8 + END_SPECIAL_COMMENT) - start_attrs_idx, end_attrs_idx = processed.index(" " * 8 + START_ATTRS_COMMENT), processed.index(" " * 8 + END_ATTRS_COMMENT) + start_idx, end_idx = processed.index(" "*2 + START_COMMENT), processed.index(" "*2 + END_COMMENT) + start_stub_idx, end_stub_idx = processed.index(" "*4 + START_SPECIAL_COMMENT), processed.index(" "*4 + END_SPECIAL_COMMENT) + start_attrs_idx, end_attrs_idx = processed.index(" "*4 + START_ATTRS_COMMENT), processed.index(" "*4 + END_ATTRS_COMMENT) - # NOTE: inline stubs __config__ attrs representation - special_attrs_lines: list[str] = [] - for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): special_attrs_lines.append(f"{' ' * 8}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n") + # NOTE: inline stubs __config__ attrs representation + special_attrs_lines: list[str] = [] + for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): + special_attrs_lines.append(f"{' ' * 4}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n") - # NOTE: inline stubs for _ConfigAttr type stubs - config_attr_lines: list[str] = [] - for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): - config_attr_lines.extend( - [ - " " * 8 + line - for line in [ - f"__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n", - f'"""{_value_docstring[keys]}"""\n', - ] - ] - ) + # NOTE: inline stubs for _ConfigAttr type stubs + config_attr_lines: list[str] = [] + for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): + config_attr_lines.extend([" "*4 + line for line in [f"__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n", f'"""{_value_docstring[keys]}"""\n',]]) - # NOTE: inline runtime __getitem__ overload process - lines: list[str] = [] - lines.append(" " * 4 + "# NOTE: ModelSettings arguments\n") - for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): - lines.extend( - [ - " " * 4 + line - for line in [ - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - f'def __getitem__(self, item: t.Literal["{keys}"]) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n', - ] - ] - ) - # special case variables: generation_class, extras, sampling_class - lines.append(" " * 4 + "# NOTE: generation_class, sampling_class and extras arguments\n") - lines.extend( - [ - " " * 4 + line - for line in [ - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - 'def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm._configuration.GenerationConfig]: ...\n', - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - 'def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm._configuration.SamplingParams]: ...\n', - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - 'def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ...\n', - ] - ] - ) - lines.append(" " * 4 + "# NOTE: GenerationConfig arguments\n") - generation_config_anns = openllm.utils.codegen.get_annotations(GenerationConfig) - for keys, type_pep563 in generation_config_anns.items(): - lines.extend( - [ - " " * 4 + line - for line in [ - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n', - ] - ] - ) - lines.append(" " * 4 + "# NOTE: SamplingParams arguments\n") - for keys, type_pep563 in openllm.utils.codegen.get_annotations(SamplingParams).items(): - if keys not in generation_config_anns: - lines.extend( - [ - " " * 4 + line - for line in [ - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n', - ] - ] - ) + # NOTE: inline runtime __getitem__ overload process + lines: list[str] = [] + lines.append(" "*2 + "# NOTE: ModelSettings arguments\n") + for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): + lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys}"]) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',]]) + # special case variables: generation_class, extras, sampling_class + lines.append(" "*2 + "# NOTE: generation_class, sampling_class and extras arguments\n") + lines.extend([ + " "*2 + line for line in [ + "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", 'def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm._configuration.GenerationConfig]: ...\n', "@overload\n" + if "overload" in dir(_imported) else "@t.overload\n", 'def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm._configuration.SamplingParams]: ...\n', "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", 'def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ...\n', + ] + ]) + lines.append(" "*2 + "# NOTE: GenerationConfig arguments\n") + generation_config_anns = openllm.utils.codegen.get_annotations(GenerationConfig) + for keys, type_pep563 in generation_config_anns.items(): + lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',]]) + lines.append(" "*2 + "# NOTE: SamplingParams arguments\n") + for keys, type_pep563 in openllm.utils.codegen.get_annotations(SamplingParams).items(): + if keys not in generation_config_anns: + lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',]]) - lines.append(" " * 4 + "# NOTE: PeftType arguments\n") - for keys in PeftType._member_names_: - lines.extend( - [ - " " * 4 + line - for line in [ - "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", - f'def __getitem__(self, item: t.Literal["{keys.lower()}"]) -> dict[str, t.Any]: ...\n', - ] - ] - ) + lines.append(" "*2 + "# NOTE: PeftType arguments\n") + for keys in PeftType._member_names_: + lines.extend([" "*2 + line for line in ["@overload\n" if "overload" in dir(_imported) else "@t.overload\n", f'def __getitem__(self, item: t.Literal["{keys.lower()}"]) -> dict[str, t.Any]: ...\n',]]) - processed = processed[:start_attrs_idx] + [" " * 8 + START_ATTRS_COMMENT, *special_attrs_lines, " " * 8 + END_ATTRS_COMMENT] + processed[end_attrs_idx + 1 : start_stub_idx] + [" " * 8 + START_SPECIAL_COMMENT, *config_attr_lines, " " * 8 + END_SPECIAL_COMMENT] + processed[end_stub_idx + 1 : start_idx] + [" " * 4 + START_COMMENT, *lines, " " * 4 + END_COMMENT] + processed[end_idx + 1 :] - with _TARGET_FILE.open("w") as f: f.writelines(processed) - return 0 + processed = processed[:start_attrs_idx] + [" "*4 + START_ATTRS_COMMENT, *special_attrs_lines, " "*4 + END_ATTRS_COMMENT] + processed[end_attrs_idx + 1:start_stub_idx] + [" "*4 + START_SPECIAL_COMMENT, *config_attr_lines, " "*4 + END_SPECIAL_COMMENT] + processed[end_stub_idx + 1:start_idx] + [" "*2 + START_COMMENT, *lines, " "*2 + END_COMMENT] + processed[end_idx + 1:] + with _TARGET_FILE.open("w") as f: + f.writelines(processed) + return 0 if __name__ == "__main__": raise SystemExit(main()) diff --git a/tools/update-models-import.py b/tools/update-models-import.py index 1e95ee05..2e9a8f60 100755 --- a/tools/update-models-import.py +++ b/tools/update-models-import.py @@ -20,25 +20,30 @@ import openllm _TARGET_FILE = Path(__file__).parent.parent / "src" / "openllm" / "models" / "__init__.py" -def comment_generator(comment_type: str, action: t.Literal["start", "stop"] = "start", indentation: int = 0) -> str: return " " * indentation + f"# {os.path.basename(__file__)}: {action} {comment_type}\n" +def comment_generator(comment_type: str, action: t.Literal["start", "stop"] = "start", indentation: int = 0) -> str: + return " "*indentation + f"# {os.path.basename(__file__)}: {action} {comment_type}\n" START_MODULE_COMMENT, STOP_MODULE_COMMENT = comment_generator("module"), comment_generator("module", "stop") -START_TYPES_COMMENT, STOP_TYPES_COMMENT = comment_generator("types", indentation=4), comment_generator("types", "stop", indentation=4) +START_TYPES_COMMENT, STOP_TYPES_COMMENT = comment_generator("types", indentation=2), comment_generator("types", "stop", indentation=2) -@openllm.utils.apply(lambda v: sorted([" " * 4 + _ for _ in v], key=lambda k: k.split()[-1])) -def create_stubs_import() -> list[str]: return [f"from . import {p.name} as {p.name}\n" for p in _TARGET_FILE.parent.glob("*/") if p.name not in {"__pycache__", "__init__.py", ".DS_Store"}] -def create_module_import() -> str: return f"_MODELS: set[str] = {{{', '.join(sorted([repr(p.name) for p in _TARGET_FILE.parent.glob('*/') if p.name not in ['__pycache__', '__init__.py', '.DS_Store']]))}}}\n" +@openllm.utils.apply(lambda v: sorted([" "*2 + _ for _ in v], key=lambda k: k.split()[-1])) +def create_stubs_import() -> list[str]: + return [f"from . import {p.name} as {p.name}\n" for p in _TARGET_FILE.parent.glob("*/") if p.name not in {"__pycache__", "__init__.py", ".DS_Store"}] + +def create_module_import() -> str: + return f"_MODELS: set[str] = {{{', '.join(sorted([repr(p.name) for p in _TARGET_FILE.parent.glob('*/') if p.name not in ['__pycache__', '__init__.py', '.DS_Store']]))}}}\n" def main() -> int: - with _TARGET_FILE.open("r") as f: processed = f.readlines() - stubs_lines, module_line = create_stubs_import(), create_module_import() - - start_module_idx, stop_module_idx = processed.index(START_MODULE_COMMENT), processed.index(STOP_MODULE_COMMENT) - start_types_idx, stop_types_idex = processed.index(START_TYPES_COMMENT), processed.index(STOP_TYPES_COMMENT) - processed = processed[:start_module_idx] + [START_MODULE_COMMENT, module_line, STOP_MODULE_COMMENT] + processed[stop_module_idx+1:start_types_idx] + [START_TYPES_COMMENT, *stubs_lines, STOP_TYPES_COMMENT] + processed[stop_types_idex+1:] - with _TARGET_FILE.open("w") as f: f.writelines(processed) - return 0 + with _TARGET_FILE.open("r") as f: + processed = f.readlines() + stubs_lines, module_line = create_stubs_import(), create_module_import() + start_module_idx, stop_module_idx = processed.index(START_MODULE_COMMENT), processed.index(STOP_MODULE_COMMENT) + start_types_idx, stop_types_idex = processed.index(START_TYPES_COMMENT), processed.index(STOP_TYPES_COMMENT) + processed = processed[:start_module_idx] + [START_MODULE_COMMENT, module_line, STOP_MODULE_COMMENT] + processed[stop_module_idx + 1:start_types_idx] + [START_TYPES_COMMENT, *stubs_lines, STOP_TYPES_COMMENT] + processed[stop_types_idex + 1:] + with _TARGET_FILE.open("w") as f: + f.writelines(processed) + return 0 if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) diff --git a/tools/update-readme.py b/tools/update-readme.py index 6c819f82..1a092013 100755 --- a/tools/update-readme.py +++ b/tools/update-readme.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import os @@ -24,76 +23,62 @@ import tomlkit import openllm - START_COMMENT = f"\n" END_COMMENT = f"\n" ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) def main() -> int: - with open(os.path.join(ROOT, "pyproject.toml"), "r") as f: deps = tomlkit.parse(f.read()).value["project"]["optional-dependencies"] - with open(os.path.join(ROOT, "README.md"), "r") as f: readme = f.readlines() + with open(os.path.join(ROOT, "pyproject.toml"), "r") as f: + deps = tomlkit.parse(f.read()).value["project"]["optional-dependencies"] + with open(os.path.join(ROOT, "README.md"), "r") as f: + readme = f.readlines() - start_index, stop_index = readme.index(START_COMMENT), readme.index(END_COMMENT) - formatted: dict[t.Literal["Model", "Architecture", "URL", "Installation", "Model Ids"], list[str | list[str]]] = { - "Model": [], - "Architecture": [], - "URL": [], - "Model Ids": [], - "Installation": [], - } - max_install_len_div = 0 - for name, config_cls in openllm.CONFIG_MAPPING.items(): - dashed = inflection.dasherize(name) - formatted["Model"].append(dashed) - formatted["Architecture"].append(config_cls.__openllm_architecture__) - formatted["URL"].append(config_cls.__openllm_url__) - formatted["Model Ids"].append(config_cls.__openllm_model_ids__) - if dashed in deps: - instruction = f'```bash\npip install "openllm[{dashed}]"\n```' - else: - instruction = "```bash\npip install openllm\n```" - if len(instruction) > max_install_len_div: - max_install_len_div = len(instruction) - formatted["Installation"].append(instruction) + start_index, stop_index = readme.index(START_COMMENT), readme.index(END_COMMENT) + formatted: dict[t.Literal["Model", "Architecture", "URL", "Installation", "Model Ids"], list[str | list[str]]] = {"Model": [], "Architecture": [], "URL": [], "Model Ids": [], "Installation": [],} + max_install_len_div = 0 + for name, config_cls in openllm.CONFIG_MAPPING.items(): + dashed = inflection.dasherize(name) + formatted["Model"].append(dashed) + formatted["Architecture"].append(config_cls.__openllm_architecture__) + formatted["URL"].append(config_cls.__openllm_url__) + formatted["Model Ids"].append(config_cls.__openllm_model_ids__) + if dashed in deps: + instruction = f'```bash\npip install "openllm[{dashed}]"\n```' + else: + instruction = "```bash\npip install openllm\n```" + if len(instruction) > max_install_len_div: + max_install_len_div = len(instruction) + formatted["Installation"].append(instruction) - meta: list[str] = ["\n", "\n"] + meta: list[str] = ["\n", "
\n"] - # NOTE: headers - meta += ["\n"] - meta.extend([f"\n" for header in formatted.keys() if header not in ("URL",)]) - meta += ["\n"] - # NOTE: rows - for name, architecture, url, model_ids, installation in t.cast(t.Iterable[t.Tuple[str, str, str, t.List[str], str]], zip(*formatted.values())): - meta += "\n" - # configure architecture URL - cfg_cls = openllm.CONFIG_MAPPING[name] - if cfg_cls.__openllm_trust_remote_code__: - arch = f"\n" - else: - model_name = { - "dolly_v2": "gpt_neox", - "stablelm": "gpt_neox", - "starcoder": "gpt_bigcode", - "flan_t5": "t5", - }.get(cfg_cls.__openllm_model_name__, cfg_cls.__openllm_model_name__) - arch = f"\n" - meta.extend( - [ - f"\n\n", - arch, - ] - ) - format_with_links: list[str] = [] - for lid in model_ids: - format_with_links.append(f"
  • {lid}
  • ") - meta.append("\n") - meta.append(f"\n") - meta += "\n" - meta.extend(["
    {header}
    {architecture}{architecture}{name}\n\n
      " + "\n".join(format_with_links) + "
    \n\n
    \n\n{installation}\n\n
    \n", "\n"]) + # NOTE: headers + meta += ["\n"] + meta.extend([f"{header}\n" for header in formatted.keys() if header not in ("URL",)]) + meta += ["\n"] + # NOTE: rows + for name, architecture, url, model_ids, installation in t.cast(t.Iterable[t.Tuple[str, str, str, t.List[str], str]], zip(*formatted.values())): + meta += "\n" + # configure architecture URL + cfg_cls = openllm.CONFIG_MAPPING[name] + if cfg_cls.__openllm_trust_remote_code__: + arch = f"{architecture}\n" + else: + model_name = {"dolly_v2": "gpt_neox", "stablelm": "gpt_neox", "starcoder": "gpt_bigcode", "flan_t5": "t5",}.get(cfg_cls.__openllm_model_name__, cfg_cls.__openllm_model_name__) + arch = f"{architecture}\n" + meta.extend([f"\n{name}\n", arch,]) + format_with_links: list[str] = [] + for lid in model_ids: + format_with_links.append(f"
  • {lid}
  • ") + meta.append("\n\n
      " + "\n".join(format_with_links) + "
    \n\n\n") + meta.append(f"\n\n{installation}\n\n\n") + meta += "\n" + meta.extend(["\n", "\n"]) - readme = readme[:start_index] + [START_COMMENT] + meta + [END_COMMENT] + readme[stop_index + 1 :] - with open(os.path.join(ROOT, "README.md"), "w") as f: f.writelines(readme) - return 0 + readme = readme[:start_index] + [START_COMMENT] + meta + [END_COMMENT] + readme[stop_index + 1:] + with open(os.path.join(ROOT, "README.md"), "w") as f: + f.writelines(readme) + return 0 if __name__ == "__main__": raise SystemExit(main()) diff --git a/tools/write-coverage-report.py b/tools/write-coverage-report.py index 0fd55550..7a9d64c2 100755 --- a/tools/write-coverage-report.py +++ b/tools/write-coverage-report.py @@ -21,51 +21,41 @@ from pathlib import Path import orjson - PRECISION = Decimal(".01") ROOT = Path(__file__).resolve().parent.parent - def main(): - coverage_summary = ROOT / "coverage-summary.json" + coverage_summary = ROOT / "coverage-summary.json" - coverage_data = orjson.loads(coverage_summary.read_text(encoding="utf-8")) - total_data = coverage_data.pop("total") + coverage_data = orjson.loads(coverage_summary.read_text(encoding="utf-8")) + total_data = coverage_data.pop("total") - lines = [ - "\n", - "Package | Statements\n", - "------- | ----------\n", - ] + lines = ["\n", "Package | Statements\n", "------- | ----------\n",] - for package, data in sorted(coverage_data.items()): - statements_covered = data["statements_covered"] - statements = data["statements"] + for package, data in sorted(coverage_data.items()): + statements_covered = data["statements_covered"] + statements = data["statements"] - rate = Decimal(statements_covered) / Decimal(statements) * 100 - rate = rate.quantize(PRECISION, rounding=ROUND_DOWN) - lines.append( - f"{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n" # noqa: PLR2004 - ) + rate = Decimal(statements_covered) / Decimal(statements) * 100 + rate = rate.quantize(PRECISION, rounding=ROUND_DOWN) + lines.append(f"{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n" # noqa: PLR2004 + ) - total_statements_covered = total_data["statements_covered"] - total_statements = total_data["statements"] - total_rate = Decimal(total_statements_covered) / Decimal(total_statements) * 100 - total_rate = total_rate.quantize(PRECISION, rounding=ROUND_DOWN) - color = "ok" if float(total_rate) >= 95 else "critical" - lines.insert(0, f"![Code Coverage](https://img.shields.io/badge/coverage-{total_rate}%25-{color}?style=flat)\n") + total_statements_covered = total_data["statements_covered"] + total_statements = total_data["statements"] + total_rate = Decimal(total_statements_covered) / Decimal(total_statements) * 100 + total_rate = total_rate.quantize(PRECISION, rounding=ROUND_DOWN) + color = "ok" if float(total_rate) >= 95 else "critical" + lines.insert(0, f"![Code Coverage](https://img.shields.io/badge/coverage-{total_rate}%25-{color}?style=flat)\n") - lines.append( - f"**Summary** | {100 if total_rate == 100 else total_rate}% " - f"({total_statements_covered} / {total_statements})\n" - ) - - coverage_report = ROOT / "coverage-report.md" - with coverage_report.open("w", encoding="utf-8") as f: - f.write("".join(lines)) - return 0 + lines.append(f"**Summary** | {100 if total_rate == 100 else total_rate}% " + f"({total_statements_covered} / {total_statements})\n") + coverage_report = ROOT / "coverage-report.md" + with coverage_report.open("w", encoding="utf-8") as f: + f.write("".join(lines)) + return 0 if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main())