style: define experimental guidelines (#168)

This commit is contained in:
Aaron Pham
2023-07-31 07:54:26 -04:00
committed by GitHub
parent 2c2070f69f
commit 8c2867d26d
128 changed files with 8314 additions and 9472 deletions

View File

@@ -32,6 +32,12 @@ repos:
types: [python]
exclude: ^(docs|tools|tests)
args: [--config=pyproject.toml]
- repo: https://github.com/google/yapf
rev: v0.40.1
hooks:
- id: yapf
types: [python]
args: [--parallel, --recursive]
- repo: local
hooks:
- id: mypy

View File

@@ -155,8 +155,13 @@ hatch run tests:snapshot-models
## Working with Git
To filter out most of the generated commits for infrastructure, use ``--invert-grep`` in conjunction with ``--grep``
to filter out all commits with regex `"[generated]"`
To filter out most of the generated commits for infrastructure, use
`--invert-grep` in conjunction with `--grep` to filter out all commits with
regex `"[generated]"`
## Style
See [STYLE.md](STYLE.md) for our style guide.
## Releasing a New Version

View File

@@ -19,6 +19,8 @@
<img src="https://img.shields.io/pypi/pyversions/openllm.svg?logo=python&label=Python&logoColor=gold" alt="python_version" />
</a><a href="https://github.com/pypa/hatch">
<img src="https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg" alt="Hatch" />
</a><a href="https://github.com/bentoml/OpenLLM/blob/main/STYLE.md">
<img src="https://img.shields.io/badge/code%20style-experimental-000000.svg" alt="code style" />
</a><a href="https://github.com/astral-sh/ruff">
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json" alt="Ruff" />
</a><a href="https://github.com/python/mypy">

160
STYLE.md Normal file
View File

@@ -0,0 +1,160 @@
## the coding style
This documentation serves as a brief discussion of the coding style used for
OpenLLM. As you have noticed, it is different from the conventional
[PEP8](https://peps.python.org/pep-0008/) style as across many Python projects.
The manifestation of OpenLLM code style is a combination of
[Google Python Style](https://google.github.io/styleguide/pyguide.html),
inspiration from coding language such as APL, Haskell, and is designed for fast,
experimental development and prototyping.
Everyone always has their own opinions on style. I believe this is exemplified
further within the Python community, as it tries to be beginner-friendly, and
therefore most people hold a very strong opinion on styling. I don't have a
strong opinion on style either (I don't have any issue with PEP8, as we use it
for our other projects), as long as:
- You don't use any linter, formatter that change the style drastically other
than what specified within the projects' [`pyproject.toml`](./pyproject.toml).
- The code you contribute is not widely different from the style of the code
surrounding it.
With that being said, I want to use this project as a playground the explore a
style that is both: "feels natural" and expressive for mathematical reasoning. I
hope that you find this guide somewhat thought-provoking and interesting, that
you can iterate and try to adopt some of them as part of the process
contributing to the library.
While PEP8 is a great base for a style guide, I find it to be having way too
much white spaces and makes the code feels 'robotic'. Having a deterministic
style and formatter is great to reduce the overhead of stylistic discussions,
but I think it is important to write code that express the intent of reasoning.
(_The policy here is definitely not "shovel everything into one line", but
rather "compact and flowing"_)
The styling is heavily inspired by
[Kenneth Iverson's](https://en.wikipedia.org/wiki/Kenneth_E._Iverson) 1979
Turing award lecture,
[Notation as a Tool of Thought](https://www.eecg.toronto.edu/~jzhu/csc326/readings/iverson.pdf),
and a lot of the stylistic inspiration comes from
[Jeremy Howard's](https://jeremy.fast.ai/) [fastai](https://docs.fast.ai/). One
thing that has been stuck with me ever since is the idea of "brevity facilitates
reasoning", as such the tersity of style aren't just for the sake of shortness,
rather the brevity of expression. (it enables
[expository programming](http://archive.vector.org.uk/art10000980), combining
with prototyping new ideas and logics within models implementation)
## some guidelines
Though I have stopped using deterministic formatter and linter, I do understand
that people have preferences for using these tools, and it plays nicely with IDE
and editors. As such, I included a [`pyproject.toml`](./pyproject.toml) file
that specifies some configuration for the tools that makes it compiliant with
the repository's style. In short, some of the tools include `ruff`, `yapf`, and
`interrogate`. Since we manage everything via `hatch`, refer back to the
[DEVELOPMENT.md](./DEVELOPMENT.md) for more information on this.
Overtime, Python has incorporated a lot of features that supports this style of
coding, including list comprehension, generator expression, lambda, array-based
programming. Yet, Python will remain verbose per se, and the goal is that to
make code fit nicely on a screen, and we don't have to always scroll downwards.
While brevity is important, it is also important to make sure functions are
somewhat, type-safe. Since there is no real type-safety when working with
Python, typing should be a best-effort to make sure we don't introduce too many
bugs.
### naming
- follow Python standard for this, I don't have too much opinion on this. Just
make sure that it is descriptive, and the abbreviation describes the intent of
the variable. i.e: `to_gpu` instead of `t_gpu`, `to_cpu` instead of `t_cpu`.
- any math-related notation or neural net layers should be expressive and stay
close to the paper as much as possible. For example: `lm_head.weight` instead
of `lm_head.w`. Espically for implementing custom kernels and layers, it is
crucial to follow its nomenclature. E.g: `conv1` instead of
`first_conv_layer`.
_If you have any suggestions, feel free to give it on our discord server!_
### layout
- Preferably not a lot of whitespaces, but rather flowing. If you can fit
everything for `if`, `def` or a `return` within one line, then there's no need
to break it into multiple line:
```python
def foo(x): return rotate_cv(x) if x > 0 else -x
```
- imports should be grouped by their types, and each import should be designated
on its own line.
```python
import os
import sys
```
This is partially to make it easier to work with merge-conflicts, and easier
for IDE to navigate context definition.
- indent with 2 spaces, which follow the Google codestyle.
- With regards to writing operator, try to follow the domain-specific notation.
I.e: when writing pathlib, just don't add space since that is not how you
write a path in the terminal. `yapf` will try to accommodate some of this
changes.
- Avoid trailing whitespace
- use array, pytorch or numpy-based indexing where possible.
- If you need to export anything, put it in `__all__` or do lazy export for
type-safe checker.
### misc
- import alias should be concise and descriptive. A convention is to always
`import typing as t`.
- Writing docstring when it is possible. No need to comment everything asn it
makes the codebase hard to read. For docstring, follow the Google style guide.
- We do lazy imports, so consult some of the `__init__.py` to see how we do it.
- Documentation is still _working-in-progress_, but tldr it will be written in
MDX and will be hosted on the GitHub Pages, so stay tuned!
- If anything that is not used for runtime, just put it under `t.TYPE_CHECKING`
### note on codegen
- We also do some codegen for some of the assignment functions. These logics are
largely based on the work of [attrs](https://github.com/python-attrs/attrs) to
ensure fast and isolated codegen in Python. If you need codegen but don't know
how it works, feel free to mention @aarnphm_ on discord!
## FAQ
### Why not use `black`?
`black` is used on our other projects, but I rather find `black` to be very
verbose and overtime it is annoying to work with too much whitespaces.
### Why not PEP8?
PEP8 is great if you are writing library such as this, but I'm going to do a lot
of experimenting for implementing papers, so I decided early on that PEP8 is
probably not fit here, and want to explore more expressive style.
### Editor is complaining about the style, what should I do?
Kindly ask you to disable linting for this project 🤗. I will try my best to
accomodate with ruff and yapf, but I don't want to spend too much time on this.
It is pretty stragithforward to disable it in your editor, with google.
### Style might put off new contributors?
I don't think so, as mentioned before, I don't have too much opinion on style as
long as it somewhat follow what I have described above or the style of the code
surrounding it. I will still accept styles PR as long as it is not too drastic.
Just make sure to add the revision to `.git-blame-ignore-revs` so that
`git blame` would work correctly.
As for people who are too close-minded about styling, such individuals aren't
the ones we want to work with anyway!

3
changelog.d/168.chore.md Normal file
View File

@@ -0,0 +1,3 @@
Define specific style guideline for the project. See
[STYLE.md](https://github.com/bentoml/OpenLLM/blob/main/STYLE.md) for more
information.

View File

@@ -24,13 +24,11 @@ llm_runner = openllm.Runner(model, llm_config=llm_config)
svc = bentoml.Service(name="llm-service", runners=[llm_runner])
@svc.on_startup
def download(_: bentoml.Context):
llm_runner.download_model()
llm_runner.download_model()
@svc.api(input=bentoml.io.Text(), output=bentoml.io.Text())
async def prompt(input_text: str) -> str:
answer = await llm_runner.generate.async_run(input_text)
return answer[0]["generated_text"]
answer = await llm_runner.generate.async_run(input_text)
return answer[0]["generated_text"]

View File

@@ -25,23 +25,20 @@ from bentoml.io import JSON
from bentoml.io import Text
class Query(BaseModel):
industry: str
product_name: str
keywords: t.List[str]
llm_config: t.Dict[str, t.Any]
industry: str
product_name: str
keywords: t.List[str]
llm_config: t.Dict[str, t.Any]
def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM:
lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
lc_llm.runner.download_model()
return lc_llm
lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
lc_llm.runner.download_model()
return lc_llm
llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b")
prompt = PromptTemplate(
input_variables=["industry", "product_name", "keywords"],
template="""
input_variables=["industry", "product_name", "keywords"], template="""
You are a Facebook Ads Copywriter with a strong background in persuasive
writing and marketing. You craft compelling copy that appeals to the target
audience's emotions and needs, peruading them to take action or make a
@@ -59,22 +56,12 @@ chain = LLMChain(llm=llm, prompt=prompt)
svc = bentoml.Service("fb-ads-copy", runners=[llm.runner])
@svc.on_startup
def download(_: bentoml.Context):
llm.runner.download_model()
SAMPLE_INPUT = Query(
industry="SAAS",
product_name="BentoML",
keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"],
llm_config=llm.runner.config.model_dump(),
)
llm.runner.download_model()
SAMPLE_INPUT = Query(industry="SAAS", product_name="BentoML", keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"], llm_config=llm.runner.config.model_dump(),)
@svc.api(input=JSON.from_sample(sample=SAMPLE_INPUT), output=Text())
def generate(query: Query):
return chain.run(
{"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)}
)
return chain.run({"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)})

View File

@@ -22,16 +22,11 @@ from bentoml.io import Text
SAMPLE_INPUT = "What is the weather in San Francisco?"
llm = OpenLLM(
model_name="dolly-v2",
model_id="databricks/dolly-v2-7b",
embedded=False,
)
llm = OpenLLM(model_name="dolly-v2", model_id="databricks/dolly-v2-7b", embedded=False,)
tools = load_tools(["serpapi"], llm=llm)
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)
svc = bentoml.Service("langchain-openllm", runners=[llm.runner])
@svc.api(input=Text.from_sample(sample=SAMPLE_INPUT), output=Text())
def chat(input_text: str):
return agent.run(input_text)
return agent.run(input_text)

View File

@@ -205,6 +205,7 @@ ignore = [
"PLR0915",
"PLR2004", # magic value to use constant
"E501", # ignore line length violation
"E702",
"PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs.
"D103", # Just missing docstring for magic methods.
"D102",
@@ -262,6 +263,49 @@ avoid-escape = false
]
"typings/**/*" = ["D", "F", "E", "PYI002"]
[tool.yapf]
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true
ALLOW_MULTILINE_DICTIONARY_KEYS = false
ALLOW_MULTILINE_LAMBDAS = false
ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS = false
ALLOW_SPLIT_BEFORE_DICT_VALUE = false
ARITHMETIC_PRECEDENCE_INDICATION = true
BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1
BLANK_LINES_BETWEEN_TOP_LEVEL_IMPORTS_AND_VARIABLES = 1
BLANK_LINE_BEFORE_CLASS_DOCSTRING = false
BLANK_LINE_BEFORE_MODULE_DOCSTRING = false
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false
COALESCE_BRACKETS = true
COLUMN_LIMIT = 384
CONTINUATION_ALIGN_STYLE = "VALIGN-RIGHT"
DEDENT_CLOSING_BRACKETS = true
DISABLE_ENDING_COMMA_HEURISTIC = true
EACH_DICT_ENTRY_ON_SEPARATE_LINE = false
INDENT_BLANK_LINES = false
INDENT_CLOSING_BRACKETS = false
INDENT_WIDTH = 2
JOIN_MULTIPLE_LINES = true
NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = true
SPACES_AROUND_SUBSCRIPT_COLON = false
SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = false
SPACE_INSIDE_BRACKETS = false
SPLIT_ALL_COMMA_SEPARATED_VALUES = false
SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES = false
SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED = false
SPLIT_BEFORE_BITWISE_OPERATOR = false
SPLIT_BEFORE_CLOSING_BRACKET = false
SPLIT_BEFORE_DICT_SET_GENERATOR = false
SPLIT_BEFORE_DOT = false
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = false
SPLIT_BEFORE_FIRST_ARGUMENT = false
SPLIT_BEFORE_LOGICAL_OPERATOR = false
SPLIT_BEFORE_NAMED_ASSIGNS = false
SPLIT_COMPLEX_COMPREHENSION = false
SPLIT_PENALTY_AFTER_OPENING_BRACKET = 10000
SPLIT_PENALTY_BEFORE_IF_EXPR = 10000
SPLIT_PENALTY_COMPREHENSION = 3000
SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT = 8000
[tool.coverage.paths]
openllm = ["src/openllm", "*/openllm/src/openllm"]
[tool.coverage.run]

View File

@@ -32,244 +32,217 @@ from . import utils as utils
from .exceptions import MissingDependencyError
if utils.DEBUG:
utils.set_debug_mode(True)
utils.set_quiet_mode(False)
logging.basicConfig(level=logging.NOTSET)
utils.set_debug_mode(True)
utils.set_quiet_mode(False)
logging.basicConfig(level=logging.NOTSET)
else:
# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
# The following warnings from bitsandbytes, and probably not that important
# for users to see when DEBUG is False
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.")
# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
# The following warnings from bitsandbytes, and probably not that important
# for users to see when DEBUG is False
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.")
_import_structure: dict[str, list[str]] = {
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
"_configuration": ["LLMConfig"],
"_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs"],
"_generation": ["StopSequenceCriteria", "StopOnTokens"],
"_quantisation": ["infer_quantisation_config"],
"exceptions": [],
"utils": ["infer_auto_class"],
"models": [],
"client": [],
"bundle": [],
"playground": [],
"testing": [],
"serialisation": ["ggml", "transformers"],
"cli.entrypoint": ["start", "start_grpc", "build", "import_model", "list_models"],
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"], "_configuration": ["LLMConfig"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs"], "_generation": ["StopSequenceCriteria", "StopOnTokens"], "_quantisation": ["infer_quantisation_config"], "exceptions": [], "utils": ["infer_auto_class"], "models": [],
"client": [], "bundle": [], "playground": [], "testing": [], "serialisation": ["ggml", "transformers"], "cli.entrypoint": ["start", "start_grpc", "build", "import_model", "list_models"],
# NOTE: models
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES", ],
"models.chatglm": ["ChatGLMConfig"],
"models.baichuan": ["BaichuanConfig"],
"models.dolly_v2": ["DollyV2Config"],
"models.falcon": ["FalconConfig"],
"models.flan_t5": ["FlanT5Config"],
"models.gpt_neox": ["GPTNeoXConfig"],
"models.llama": ["LlamaConfig"],
"models.mpt": ["MPTConfig"],
"models.opt": ["OPTConfig"],
"models.stablelm": ["StableLMConfig"],
"models.starcoder": ["StarCoderConfig"],
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"],
"models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"],
}
# NOTE: torch and cpm_kernels
try:
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_pt_and_cpm_kernels_objects
_import_structure["utils.dummy_pt_and_cpm_kernels_objects"] = [name for name in dir(dummy_pt_and_cpm_kernels_objects) if not name.startswith("_")]
from .utils import dummy_pt_and_cpm_kernels_objects
_import_structure["utils.dummy_pt_and_cpm_kernels_objects"] = [name for name in dir(dummy_pt_and_cpm_kernels_objects) if not name.startswith("_")]
else:
_import_structure["models.chatglm"].extend(["ChatGLM"])
_import_structure["models.baichuan"].extend(["Baichuan"])
_import_structure["models.chatglm"].extend(["ChatGLM"])
_import_structure["models.baichuan"].extend(["Baichuan"])
try:
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_pt_and_einops_objects
_import_structure["utils.dummy_pt_and_einops_objects"] = [name for name in dir(dummy_pt_and_einops_objects) if not name.startswith("_")]
from .utils import dummy_pt_and_einops_objects
_import_structure["utils.dummy_pt_and_einops_objects"] = [name for name in dir(dummy_pt_and_einops_objects) if not name.startswith("_")]
else:
_import_structure["models.falcon"].extend(["Falcon"])
_import_structure["models.falcon"].extend(["Falcon"])
try:
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_pt_and_triton_objects
_import_structure["utils.dummy_pt_and_triton_objects"] = [name for name in dir(dummy_pt_and_triton_objects) if not name.startswith("_")]
from .utils import dummy_pt_and_triton_objects
_import_structure["utils.dummy_pt_and_triton_objects"] = [name for name in dir(dummy_pt_and_triton_objects) if not name.startswith("_")]
else:
_import_structure["models.mpt"].extend(["MPT"])
_import_structure["models.mpt"].extend(["MPT"])
try:
if not utils.is_torch_available(): raise MissingDependencyError
if not utils.is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_pt_objects
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
from .utils import dummy_pt_objects
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["models.flan_t5"].extend(["FlanT5"])
_import_structure["models.dolly_v2"].extend(["DollyV2"])
_import_structure["models.starcoder"].extend(["StarCoder"])
_import_structure["models.stablelm"].extend(["StableLM"])
_import_structure["models.opt"].extend(["OPT"])
_import_structure["models.gpt_neox"].extend(["GPTNeoX"])
_import_structure["models.llama"].extend(["Llama"])
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"])
_import_structure["models.flan_t5"].extend(["FlanT5"])
_import_structure["models.dolly_v2"].extend(["DollyV2"])
_import_structure["models.starcoder"].extend(["StarCoder"])
_import_structure["models.stablelm"].extend(["StableLM"])
_import_structure["models.opt"].extend(["OPT"])
_import_structure["models.gpt_neox"].extend(["GPTNeoX"])
_import_structure["models.llama"].extend(["Llama"])
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"])
try:
if not utils.is_vllm_available(): raise MissingDependencyError
if not utils.is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_vllm_objects
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(dummy_vllm_objects) if not name.startswith("_")]
from .utils import dummy_vllm_objects
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(dummy_vllm_objects) if not name.startswith("_")]
else:
_import_structure["models.llama"].extend(["VLLMLlama"])
_import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
_import_structure["models.llama"].extend(["VLLMLlama"])
_import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
try:
if not utils.is_flax_available(): raise MissingDependencyError
if not utils.is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_flax_objects
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")]
from .utils import dummy_flax_objects
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")]
else:
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
_import_structure["models.opt"].extend(["FlaxOPT"])
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
_import_structure["models.opt"].extend(["FlaxOPT"])
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
try:
if not utils.is_tf_available(): raise MissingDependencyError
if not utils.is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils import dummy_tf_objects
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
from .utils import dummy_tf_objects
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
else:
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
_import_structure["models.opt"].extend(["TFOPT"])
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
_import_structure["models.opt"].extend(["TFOPT"])
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
# declaration for OpenLLM-related modules
if t.TYPE_CHECKING:
from . import bundle as bundle
from . import cli as cli
from . import client as client
from . import exceptions as exceptions
from . import models as models
from . import playground as playground
from . import serialisation as serialisation
from . import testing as testing
from . import bundle as bundle
from . import cli as cli
from . import client as client
from . import exceptions as exceptions
from . import models as models
from . import playground as playground
from . import serialisation as serialisation
from . import testing as testing
# Specific types import
from ._configuration import LLMConfig as LLMConfig
from ._generation import StopOnTokens as StopOnTokens
from ._generation import StopSequenceCriteria as StopSequenceCriteria
from ._llm import LLM as LLM
from ._llm import LLMRunnable as LLMRunnable
from ._llm import LLMRunner as LLMRunner
from ._llm import Runner as Runner
from ._quantisation import infer_quantisation_config as infer_quantisation_config
from ._schema import EmbeddingsOutput as EmbeddingsOutput
from ._schema import GenerationInput as GenerationInput
from ._schema import GenerationOutput as GenerationOutput
from ._schema import MetadataOutput as MetadataOutput
from ._schema import unmarshal_vllm_outputs as unmarshal_vllm_outputs
from .cli.entrypoint import build as build
from .cli.entrypoint import import_model as import_model
from .cli.entrypoint import list_models as list_models
from .cli.entrypoint import start as start
from .cli.entrypoint import start_grpc as start_grpc
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .models.auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
from .models.auto import AutoConfig as AutoConfig
from .models.baichuan import BaichuanConfig as BaichuanConfig
from .models.chatglm import ChatGLMConfig as ChatGLMConfig
from .models.dolly_v2 import DollyV2Config as DollyV2Config
from .models.falcon import FalconConfig as FalconConfig
from .models.flan_t5 import FlanT5Config as FlanT5Config
from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig
from .models.llama import LlamaConfig as LlamaConfig
from .models.mpt import MPTConfig as MPTConfig
from .models.opt import OPTConfig as OPTConfig
from .models.stablelm import StableLMConfig as StableLMConfig
from .models.starcoder import StarCoderConfig as StarCoderConfig
from .serialisation import ggml as ggml
from .serialisation import transformers as transformers
from .utils import infer_auto_class as infer_auto_class
# Specific types import
from ._configuration import LLMConfig as LLMConfig
from ._generation import StopOnTokens as StopOnTokens
from ._generation import StopSequenceCriteria as StopSequenceCriteria
from ._llm import LLM as LLM
from ._llm import LLMRunnable as LLMRunnable
from ._llm import LLMRunner as LLMRunner
from ._llm import Runner as Runner
from ._quantisation import infer_quantisation_config as infer_quantisation_config
from ._schema import EmbeddingsOutput as EmbeddingsOutput
from ._schema import GenerationInput as GenerationInput
from ._schema import GenerationOutput as GenerationOutput
from ._schema import MetadataOutput as MetadataOutput
from ._schema import unmarshal_vllm_outputs as unmarshal_vllm_outputs
from .cli.entrypoint import build as build
from .cli.entrypoint import import_model as import_model
from .cli.entrypoint import list_models as list_models
from .cli.entrypoint import start as start
from .cli.entrypoint import start_grpc as start_grpc
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .models.auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
from .models.auto import AutoConfig as AutoConfig
from .models.baichuan import BaichuanConfig as BaichuanConfig
from .models.chatglm import ChatGLMConfig as ChatGLMConfig
from .models.dolly_v2 import DollyV2Config as DollyV2Config
from .models.falcon import FalconConfig as FalconConfig
from .models.flan_t5 import FlanT5Config as FlanT5Config
from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig
from .models.llama import LlamaConfig as LlamaConfig
from .models.mpt import MPTConfig as MPTConfig
from .models.opt import OPTConfig as OPTConfig
from .models.stablelm import StableLMConfig as StableLMConfig
from .models.starcoder import StarCoderConfig as StarCoderConfig
from .serialisation import ggml as ggml
from .serialisation import transformers as transformers
from .utils import infer_auto_class as infer_auto_class
# NOTE: torch and cpm_kernels
try:
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_and_cpm_kernels_objects import *
else:
from .models.baichuan import Baichuan as Baichuan
from .models.chatglm import ChatGLM as ChatGLM
# NOTE: torch and einops
try:
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_and_einops_objects import *
else:
from .models.falcon import Falcon as Falcon
# NOTE: torch and triton
try:
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_and_triton_objects import *
else:
from .models.mpt import MPT as MPT
try:
if not utils.is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_objects import *
else:
from .models.auto import MODEL_MAPPING as MODEL_MAPPING
from .models.auto import AutoLLM as AutoLLM
from .models.dolly_v2 import DollyV2 as DollyV2
from .models.flan_t5 import FlanT5 as FlanT5
from .models.gpt_neox import GPTNeoX as GPTNeoX
from .models.llama import Llama as Llama
from .models.opt import OPT as OPT
from .models.stablelm import StableLM as StableLM
from .models.starcoder import StarCoder as StarCoder
try:
if not utils.is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_vllm_objects import *
else:
from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING
from .models.auto import AutoVLLM as AutoVLLM
from .models.llama import VLLMLlama as VLLMLlama
from .models.opt import VLLMOPT as VLLMOPT
try:
if not utils.is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_flax_objects import *
else:
from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
from .models.auto import AutoFlaxLLM as AutoFlaxLLM
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
from .models.opt import FlaxOPT as FlaxOPT
try:
if not utils.is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_tf_objects import *
else:
from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
from .models.auto import AutoTFLLM as AutoTFLLM
from .models.flan_t5 import TFFlanT5 as TFFlanT5
from .models.opt import TFOPT as TFOPT
# NOTE: torch and cpm_kernels
try:
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_and_cpm_kernels_objects import *
else:
from .models.baichuan import Baichuan as Baichuan
from .models.chatglm import ChatGLM as ChatGLM
# NOTE: torch and einops
try:
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_and_einops_objects import *
else:
from .models.falcon import Falcon as Falcon
# NOTE: torch and triton
try:
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_and_triton_objects import *
else:
from .models.mpt import MPT as MPT
try:
if not utils.is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_pt_objects import *
else:
from .models.auto import MODEL_MAPPING as MODEL_MAPPING
from .models.auto import AutoLLM as AutoLLM
from .models.dolly_v2 import DollyV2 as DollyV2
from .models.flan_t5 import FlanT5 as FlanT5
from .models.gpt_neox import GPTNeoX as GPTNeoX
from .models.llama import Llama as Llama
from .models.opt import OPT as OPT
from .models.stablelm import StableLM as StableLM
from .models.starcoder import StarCoder as StarCoder
try:
if not utils.is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_vllm_objects import *
else:
from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING
from .models.auto import AutoVLLM as AutoVLLM
from .models.llama import VLLMLlama as VLLMLlama
from .models.opt import VLLMOPT as VLLMOPT
try:
if not utils.is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_flax_objects import *
else:
from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
from .models.auto import AutoFlaxLLM as AutoFlaxLLM
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
from .models.opt import FlaxOPT as FlaxOPT
try:
if not utils.is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
from .utils.dummy_tf_objects import *
else:
from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
from .models.auto import AutoTFLLM as AutoTFLLM
from .models.flan_t5 import TFFlanT5 as TFFlanT5
from .models.opt import TFOPT as TFOPT
else: sys.modules[__name__] = utils.LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, doc=__doc__,
extra_objects={
else:
sys.modules[__name__] = utils.LazyModule(
__name__,
globals()["__file__"], _import_structure, module_spec=__spec__, doc=__doc__, extra_objects={
# The below is a special mapping that allows openllm to be used as a dictionary.
# This is purely for convenience sake, and should not be used in performance critcal
# code. This is also not considered as a public API.
"__openllm_special__": {"flax": "AutoFlaxLLM", "tf": "AutoTFLLM", "pt": "AutoLLM", "vllm": "AutoVLLM"},
})
}
)

View File

@@ -20,5 +20,5 @@ To start any OpenLLM model:
openllm start <model_name> --options ...
"""
if __name__ == "__main__":
from openllm.cli.entrypoint import cli
cli()
from openllm.cli.entrypoint import cli
cli()

View File

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities to be reused throughout."""
from __future__ import annotations
import typing as t
@@ -19,29 +18,12 @@ import typing as t
import transformers
if t.TYPE_CHECKING:
import torch
import torch
class StopSequenceCriteria(transformers.StoppingCriteria):
"""This class used to stop generation when a seq of tokens are met.
Args:
stop_sequences: `str` or `list[str]` of the sequence (list of sequences) on which to stop execution.
tokenizer: Tokenizer to be used to decode the model outputs.
"""
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer):
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **attrs: t.Any) -> bool:
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer):
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
self.stop_sequences,self.tokenizer = stop_sequences, tokenizer
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **attrs: t.Any) -> bool: return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
class StopOnTokens(transformers.StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool:
stop_ids = {50278, 50279, 50277, 1, 0}
return t.cast(int, input_ids[0][-1]) in stop_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool: return t.cast(int, input_ids[0][-1]) in {50278, 50279, 50277, 1, 0}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -16,33 +16,25 @@ import string
import typing as t
class PromptFormatter(string.Formatter):
"""This PromptFormatter is largely based on langchain's implementation."""
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
if len(args) > 0:
raise ValueError("Positional arguments are not supported")
return super().vformat(format_string, args, kwargs)
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
"""Check if extra params is passed."""
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f"Extra params passed: {extras}")
def extract_template_variables(self, template: str) -> t.Sequence[str]:
"""Extract template variables from a template string."""
return [field[1] for field in self.parse(template) if field[1] is not None]
"""This PromptFormatter is largely based on langchain's implementation."""
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
if len(args) > 0: raise ValueError("Positional arguments are not supported")
return super().vformat(format_string, args, kwargs)
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
extras = set(kwargs).difference(used_args)
if extras: raise KeyError(f"Extra params passed: {extras}")
def extract_template_variables(self, template: str) -> t.Sequence[str]: return [field[1] for field in self.parse(template) if field[1] is not None]
default_formatter = PromptFormatter()
def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str:
# Currently, all default prompt will always have `instruction` key.
if not use_prompt_template: return prompt
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
template_variables = default_formatter.extract_template_variables(template)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
try: return template.format(instruction=prompt, **prompt_variables)
except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
# Currently, all default prompt will always have `instruction` key.
if not use_prompt_template: return prompt
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
template_variables = default_formatter.extract_template_variables(template)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
try:
return template.format(instruction=prompt, **prompt_variables)
except KeyError as e:
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None

View File

@@ -25,119 +25,78 @@ from .utils import pkg
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
from typing import overload
else:
from typing_extensions import overload
from typing_extensions import overload
if t.TYPE_CHECKING:
import auto_gptq as autogptq
import torch
import auto_gptq as autogptq
import torch
import openllm
import transformers
import openllm
import transformers
from ._types import DictStrAny
from ._types import DictStrAny
else:
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
torch = LazyLoader("torch", globals(), "torch")
transformers = LazyLoader("transformers", globals(), "transformers")
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
torch = LazyLoader("torch", globals(), "torch")
transformers = LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
QuantiseMode = t.Literal["int8", "int4", "gptq"]
# fmt: off
@overload
def infer_quantisation_config(
cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
...
def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ...
@overload
def infer_quantisation_config(
cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any
) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
...
def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]: ...
# fmt: on
def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
# 8 bit configuration
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
autogptq_attrs: DictStrAny = {"bits": attrs.pop("gptq_bits", 4), "group_size": attrs.pop("gptq_group_size", -1), "damp_percent": attrs.pop("gptq_damp_percent", 0.01), "desc_act": attrs.pop("gptq_desc_act", True), "sym": attrs.pop("gptq_sym", True), "true_sequential": attrs.pop("gptq_true_sequential", True),}
def infer_quantisation_config(
cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
# 8 bit configuration
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
if int8_skip_modules is None: int8_skip_modules = []
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
int8_skip_modules.append("lm_head")
return transformers.BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload, llm_int8_threshhold=int8_threshold, llm_int8_skip_modules=int8_skip_modules, llm_int8_has_fp16_weight=int8_has_fp16_weight,)
autogptq_attrs: DictStrAny = {
"bits": attrs.pop("gptq_bits", 4),
"group_size": attrs.pop("gptq_group_size", -1),
"damp_percent": attrs.pop("gptq_damp_percent", 0.01),
"desc_act": attrs.pop("gptq_desc_act", True),
"sym": attrs.pop("gptq_sym", True),
"true_sequential": attrs.pop("gptq_true_sequential", True),
}
# 4 bit configuration
int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
if int8_skip_modules is None:
int8_skip_modules = []
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
int8_skip_modules.append("lm_head")
return transformers.BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
llm_int8_threshhold=int8_threshold,
llm_int8_skip_modules=int8_skip_modules,
llm_int8_has_fp16_weight=int8_has_fp16_weight,
)
# 4 bit configuration
int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)
# NOTE: Quantization setup
# quantize is a openllm.LLM feature, where we can quantize the model
# with bitsandbytes or quantization aware training.
if not is_bitsandbytes_available():
raise RuntimeError(
"Quantization requires bitsandbytes to be installed. Make "
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
)
if quantise == "int8":
quantisation_config = create_int8_config(int8_skip_modules)
elif quantise == "int4":
if is_transformers_supports_kbit():
quantisation_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=int4_compute_dtype,
bnb_4bit_quant_type=int4_quant_type,
bnb_4bit_use_double_quant=int4_use_double_quant,
)
else:
logger.warning(
"'quantize' is set to int4, while the current transformers version %s does not support "
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
"make sure to install the latest version of transformers either via PyPI or "
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
pkg.pkg_version_info("transformers"),
)
logger.warning("OpenLLM will fallback to 8-bit quantization.")
quantisation_config = create_int8_config(int8_skip_modules)
elif quantise == "gptq":
if not is_autogptq_available():
logger.warning(
"'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment)."
" Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback "
"to int8 with bitsandbytes."
)
quantisation_config = create_int8_config(int8_skip_modules)
else:
quantisation_config = autogptq.BaseQuantizeConfig(**autogptq_attrs)
# NOTE: Quantization setup
# quantize is a openllm.LLM feature, where we can quantize the model
# with bitsandbytes or quantization aware training.
if not is_bitsandbytes_available(): raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
if quantise == "int8": quantisation_config = create_int8_config(int8_skip_modules)
elif quantise == "int4":
if is_transformers_supports_kbit(): quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=int4_compute_dtype, bnb_4bit_quant_type=int4_quant_type, bnb_4bit_use_double_quant=int4_use_double_quant,)
else:
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")
return quantisation_config, attrs
logger.warning(
"'quantize' is set to int4, while the current transformers version %s does not support "
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
"make sure to install the latest version of transformers either via PyPI or "
"from git source: 'pip install git+https://github.com/huggingface/transformers'.", pkg.pkg_version_info("transformers"),
)
logger.warning("OpenLLM will fallback to 8-bit quantization.")
quantisation_config = create_int8_config(int8_skip_modules)
elif quantise == "gptq":
if not is_autogptq_available():
logger.warning("'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment)."
" Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback "
"to int8 with bitsandbytes.")
quantisation_config = create_int8_config(int8_skip_modules)
else:
quantisation_config = autogptq.BaseQuantizeConfig(**autogptq_attrs)
else:
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")
return quantisation_config, attrs

View File

@@ -27,109 +27,71 @@ from .utils import bentoml_cattr
from .utils import requires_dependencies
if t.TYPE_CHECKING:
import vllm
import vllm
from ._types import DictStrAny
from ._types import DictStrAny
else:
DictStrAny = dict
vllm = LazyLoader("vllm", globals(), "vllm")
DictStrAny = dict
vllm = LazyLoader("vllm", globals(), "vllm")
@attr.frozen(slots=True)
class GenerationInput:
prompt: str
"""The prompt to be sent to system."""
prompt: str
"""The prompt to be sent to system."""
llm_config: LLMConfig
"""A mapping of given LLM configuration values for given system."""
@staticmethod
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
if isinstance(data, LLMConfig): return data
elif LazyType(DictStrAny).isinstance(data):
if cls is None: raise ValueError("'cls' must pass if given data is a dictionary.")
return cls(**data)
else:
raise RuntimeError(f"Type {type(data)} is not yet supported.")
llm_config: LLMConfig
"""A mapping of given LLM configuration values for given system."""
@staticmethod
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
if isinstance(data, LLMConfig):
return data
elif LazyType(DictStrAny).isinstance(data):
if cls is None:
raise ValueError("'cls' must pass if given data is a dictionary.")
return cls(**data)
else:
raise RuntimeError(f"Type {type(data)} is not yet supported.")
@classmethod
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]:
from .models.auto import AutoConfig
llm_config = AutoConfig.for_model(model_name, **attrs)
return attr.make_class(
inflection.camelize(llm_config["model_name"]) + "GenerationInput",
attrs={
"prompt": attr.field(type=str),
"llm_config": attr.field(
type=llm_config.__class__,
default=llm_config,
converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__),
),
},
)
def model_dump(self) -> dict[str, t.Any]:
return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True)}
@classmethod
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]:
from .models.auto import AutoConfig
llm_config = AutoConfig.for_model(model_name, **attrs)
return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__))})
def model_dump(self) -> dict[str, t.Any]:
return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True)}
@attr.frozen(slots=True)
class GenerationOutput:
responses: t.List[t.Any]
"""A list of responses from the system."""
responses: t.List[t.Any]
"""A list of responses from the system."""
configuration: t.Dict[str, t.Any]
"""A mapping of configuration values for given system."""
@property
def marshaled_config(self) -> GenerationConfig:
return bentoml_cattr.structure(self.configuration, GenerationConfig)
configuration: t.Dict[str, t.Any]
"""A mapping of configuration values for given system."""
@property
def marshaled_config(self) -> GenerationConfig:
return bentoml_cattr.structure(self.configuration, GenerationConfig)
@property
def unmarshaled(self) -> dict[str, t.Any]:
return bentoml_cattr.unstructure(self)
def __getitem__(self, key: str) -> t.Any:
if hasattr(self, key): return getattr(self, key)
elif key in self.configuration: return self.configuration[key]
else: raise KeyError(key)
@property
def unmarshaled(self) -> dict[str, t.Any]:
return bentoml_cattr.unstructure(self)
def __getitem__(self, key: str) -> t.Any:
if hasattr(self, key): return getattr(self, key)
elif key in self.configuration: return self.configuration[key]
else: raise KeyError(key)
@attr.frozen(slots=True)
class MetadataOutput:
model_id: str
timeout: int
model_name: str
framework: str
configuration: str
supports_embeddings: bool
supports_hf_agent: bool
model_id: str
timeout: int
model_name: str
framework: str
configuration: str
supports_embeddings: bool
supports_hf_agent: bool
@attr.frozen(slots=True)
class EmbeddingsOutput:
embeddings: t.List[t.List[float]]
num_tokens: int
embeddings: t.List[t.List[float]]
num_tokens: int
@requires_dependencies("vllm", extra="vllm")
def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> DictStrAny:
return dict(
request_id=request_output.request_id,
prompt=request_output.prompt,
finished=request_output.finished,
prompt_token_ids=request_output.prompt_token_ids,
outputs=[
dict(
index=it.index,
text=it.text,
token_ids=it.token_ids,
cumulative_logprob=it.cumulative_logprob,
logprobs=it.logprobs,
finish_reason=it.finish_reason,
)
for it in request_output.outputs
],
)
return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs],)

View File

@@ -25,8 +25,8 @@ from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
if t.TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from starlette.requests import Request
from starlette.responses import Response
# The following warnings from bitsandbytes, and probably not that important for users to see
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
@@ -36,37 +36,59 @@ adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""
llm_config = openllm.AutoConfig.for_model(model)
runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map))
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner])
@svc.api(input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}), route="/v1/generate")
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict)
config = qa_inputs.llm_config.model_dump()
responses = await runner.generate.async_run(qa_inputs.prompt, **config)
return openllm.GenerationOutput(responses=responses, configuration=config)
qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict)
config = qa_inputs.llm_config.model_dump()
responses = await runner.generate.async_run(qa_inputs.prompt, **config)
return openllm.GenerationOutput(responses=responses, configuration=config)
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent}), route="/v1/metadata")
def metadata_v1(_: str) -> openllm.MetadataOutput: return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,)
def metadata_v1(_: str) -> openllm.MetadataOutput:
return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,)
if runner.supports_embeddings:
@svc.api(input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings")
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
responses = await runner.embeddings.async_run(phrases)
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
@svc.api(
input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({
"embeddings": [
0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918,
0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076
], "num_tokens": 20
}), route="/v1/embeddings"
)
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
responses = await runner.embeddings.async_run(phrases)
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
@attr.define
class HfAgentInput:
inputs: str
parameters: t.Dict[str, t.Any]
async def hf_agent(request: Request) -> Response:
json_str = await request.body()
try: input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), HfAgentInput)
except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
stop = input_data.parameters.pop("stop", ["\n"])
try: return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
except NotImplementedError: return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
svc.mount_asgi_app(hf_app, path="/hf")
@attr.define
class HfAgentInput:
inputs: str
parameters: t.Dict[str, t.Any]
async def hf_agent(request: Request) -> Response:
json_str = await request.body()
try:
input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), HfAgentInput)
except orjson.JSONDecodeError as err:
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
stop = input_data.parameters.pop("stop", ["\n"])
try:
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
except NotImplementedError:
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
svc.mount_asgi_app(hf_app, path="/hf")
async def list_adapter_v1(_: Request) -> Response:
res: dict[str, t.Any] = {}
if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]})
return JSONResponse(res, status_code=200)
res: dict[str, t.Any] = {}
if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]})
return JSONResponse(res, status_code=200)
adapters_app_v1 = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])])
svc.mount_asgi_app(adapters_app_v1, path="/v1")

View File

@@ -35,223 +35,220 @@ from .utils import LazyType
from .utils import ReprMixin
if t.TYPE_CHECKING:
ListIntStr = list[int | str]
class DynResource(bentoml.Resource[t.List[str]], resource_id=""):
resource_id: t.ClassVar[str]
ListIntStr = list[int | str]
class DynResource(bentoml.Resource[t.List[str]], resource_id=""):
resource_id: t.ClassVar[str]
else:
DynResource = bentoml.Resource[t.List[str]]
ListIntStr = list
DynResource = bentoml.Resource[t.List[str]]
ListIntStr = list
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
from typing import overload
else:
from typing_extensions import overload
from typing_extensions import overload
logger = logging.getLogger(__name__)
def _strtoul(s: str) -> int:
"""Return -1 or positive integer sequence string starts with,."""
if not s: return -1
idx = 0
for idx, c in enumerate(s):
if not (c.isdigit() or (idx == 0 and c in "+-")): break
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
# NOTE: idx will be set via enumerate
return int(s[:idx]) if idx > 0 else -1
"""Return -1 or positive integer sequence string starts with,."""
if not s: return -1
idx = 0
for idx, c in enumerate(s):
if not (c.isdigit() or (idx == 0 and c in "+-")): break
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
# NOTE: idx will be set via enumerate
return int(s[:idx]) if idx > 0 else -1
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
rcs: list[str] = []
for elem in lst.split(","):
# Repeated id results in empty set
if elem in rcs: return []
# Anything other but prefix is ignored
if not elem.startswith(prefix): break
rcs.append(elem)
return rcs
rcs: list[str] = []
for elem in lst.split(","):
# Repeated id results in empty set
if elem in rcs: return []
# Anything other but prefix is ignored
if not elem.startswith(prefix): break
rcs.append(elem)
return rcs
_STACK_LEVEL = 3
@overload
def _parse_visible_devices(default_var: str | None = ..., respect_env: t.Literal[True] = True) -> list[str] | None: ...
def _parse_visible_devices(default_var: str | None = ..., respect_env: t.Literal[True] = True) -> list[str] | None:
...
@overload
def _parse_visible_devices(default_var: str = ..., respect_env: t.Literal[False] = ...) -> list[str]: ...
def _parse_visible_devices(default_var: str = ..., respect_env: t.Literal[False] = ...) -> list[str]:
...
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
if respect_env:
spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var)
if not spec: return None
else:
if default_var is None: raise ValueError("spec is required to be not None when parsing spec.")
spec = default_var
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
if respect_env:
spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var)
if not spec: return None
else:
if default_var is None: raise ValueError("spec is required to be not None when parsing spec.")
spec = default_var
if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-")
if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-")
# XXX: We to somehow handle cases such as '100m'
# CUDA_VISIBLE_DEVICES uses something like strtoul
# which makes `1gpu2,2ampere` is equivalent to `1,2`
rc: list[int] = []
for el in spec.split(","):
x = _strtoul(el.strip())
# Repeated ordinal results in empty set
if x in rc: return []
# Negative value aborts the sequence
if x < 0: break
rc.append(x)
return [str(i) for i in rc]
if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-")
if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-")
# XXX: We to somehow handle cases such as '100m'
# CUDA_VISIBLE_DEVICES uses something like strtoul
# which makes `1gpu2,2ampere` is equivalent to `1,2`
rc: list[int] = []
for el in spec.split(","):
x = _strtoul(el.strip())
# Repeated ordinal results in empty set
if x in rc: return []
# Negative value aborts the sequence
if x < 0: break
rc.append(x)
return [str(i) for i in rc]
def _from_system(cls: type[DynResource]) -> list[str]:
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES.
"""
visible_devices = _parse_visible_devices()
if visible_devices is None:
if cls.resource_id == "amd.com/gpu":
if not psutil.LINUX:
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
return []
It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES.
"""
visible_devices = _parse_visible_devices()
if visible_devices is None:
if cls.resource_id == "amd.com/gpu":
if not psutil.LINUX:
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
return []
# ROCm does not currently have the rocm_smi wheel.
# So we need to use the ctypes bindings directly.
# we don't want to use CLI because parsing is a pain.
sys.path.append("/opt/rocm/libexec/rocm_smi")
try:
from ctypes import byref
from ctypes import c_uint32
# ROCm does not currently have the rocm_smi wheel.
# So we need to use the ctypes bindings directly.
# we don't want to use CLI because parsing is a pain.
sys.path.append("/opt/rocm/libexec/rocm_smi")
try:
from ctypes import byref
from ctypes import c_uint32
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
from rsmiBindings import rocmsmi
from rsmiBindings import rsmi_status_t
device_count = c_uint32(0)
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
return []
# In this case the binary is not found, returning empty list
except (ModuleNotFoundError, ImportError): return []
finally: sys.path.remove("/opt/rocm/libexec/rocm_smi")
else:
try:
from cuda import cuda
cuda.cuInit(0)
_, dev = cuda.cuDeviceGetCount()
return [str(i) for i in range(dev)]
except (ImportError, RuntimeError, AttributeError): return []
return visible_devices
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
from rsmiBindings import rocmsmi
from rsmiBindings import rsmi_status_t
device_count = c_uint32(0)
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
return []
# In this case the binary is not found, returning empty list
except (ModuleNotFoundError, ImportError):
return []
finally:
sys.path.remove("/opt/rocm/libexec/rocm_smi")
else:
try:
from cuda import cuda
cuda.cuInit(0)
_, dev = cuda.cuDeviceGetCount()
return [str(i) for i in range(dev)]
except (ImportError, RuntimeError, AttributeError):
return []
return visible_devices
@overload
def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ...
def _from_spec(cls: type[DynResource], spec: int) -> list[str]:
...
@overload
def _from_spec(cls: type[DynResource], spec: ListIntStr) -> list[str]: ...
def _from_spec(cls: type[DynResource], spec: ListIntStr) -> list[str]:
...
@overload
def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ...
def _from_spec(cls: type[DynResource], spec: str) -> list[str]:
...
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
The parser behaves similar to how PyTorch handles CUDA_VISIBLE_DEVICES. This means within
BentoML's resource configuration, its behaviour is similar to CUDA_VISIBLE_DEVICES.
"""
if isinstance(spec, int):
if spec in (-1, 0): return []
if spec < -1: raise ValueError("Spec cannot be < -1.")
return [str(i) for i in range(spec)]
elif isinstance(spec, str):
if not spec: return []
if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))])
return _parse_visible_devices(spec, respect_env=False)
elif LazyType(ListIntStr).isinstance(spec): return [str(x) for x in spec]
else: raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
The parser behaves similar to how PyTorch handles CUDA_VISIBLE_DEVICES. This means within
BentoML's resource configuration, its behaviour is similar to CUDA_VISIBLE_DEVICES.
"""
if isinstance(spec, int):
if spec in (-1, 0): return []
if spec < -1: raise ValueError("Spec cannot be < -1.")
return [str(i) for i in range(spec)]
elif isinstance(spec, str):
if not spec: return []
if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))])
return _parse_visible_devices(spec, respect_env=False)
elif LazyType(ListIntStr).isinstance(spec):
return [str(x) for x in spec]
else:
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
def _raw_device_uuid_nvml() -> list[str] | None:
"""Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed."""
from ctypes import CDLL
from ctypes import byref
from ctypes import c_int
from ctypes import c_void_p
from ctypes import create_string_buffer
"""Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed."""
from ctypes import CDLL
from ctypes import byref
from ctypes import c_int
from ctypes import c_void_p
from ctypes import create_string_buffer
try: nvml_h = CDLL("libnvidia-ml.so.1")
except Exception:
warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL)
return None
try:
nvml_h = CDLL("libnvidia-ml.so.1")
except Exception:
warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL)
return None
rc = nvml_h.nvmlInit()
rc = nvml_h.nvmlInit()
if rc != 0:
warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL)
return None
dev_count = c_int(-1)
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
if rc != 0:
warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL)
return None
uuids: list[str] = []
for idx in range(dev_count.value):
dev_id = c_void_p()
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
if rc != 0:
warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL)
return None
dev_count = c_int(-1)
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL)
return None
buf_len = 96
buf = create_string_buffer(buf_len)
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
if rc != 0:
warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL)
return None
uuids: list[str] = []
for idx in range(dev_count.value):
dev_id = c_void_p()
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
if rc != 0:
warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL)
return None
buf_len = 96
buf = create_string_buffer(buf_len)
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
if rc != 0:
warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL)
return None
uuids.append(buf.raw.decode("ascii").strip("\0"))
del nvml_h
return uuids
warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL)
return None
uuids.append(buf.raw.decode("ascii").strip("\0"))
del nvml_h
return uuids
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
if cls.resource_id == "amd.com/gpu":
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.")
if cls.resource_id == "amd.com/gpu":
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.")
try:
from cuda import cuda
err, *_ = cuda.cuInit(0)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Failed to initialise CUDA runtime binding.")
# correctly parse handle
for el in val:
if el.startswith("GPU-") or el.startswith("MIG-"):
uuids = _raw_device_uuid_nvml()
if uuids is None: raise ValueError("Failed to parse available GPUs UUID")
if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})")
elif el.isdigit():
err, _ = cuda.cuDeviceGet(int(el))
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}")
except (ImportError, RuntimeError):
pass
try:
from cuda import cuda
err, *_ = cuda.cuInit(0)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Failed to initialise CUDA runtime binding.")
# correctly parse handle
for el in val:
if el.startswith("GPU-") or el.startswith("MIG-"):
uuids = _raw_device_uuid_nvml()
if uuids is None: raise ValueError("Failed to parse available GPUs UUID")
if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})")
elif el.isdigit():
err, _ = cuda.cuDeviceGet(int(el))
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}")
except (ImportError, RuntimeError):
pass
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
return types.new_class(
name,
(DynResource, ReprMixin),
{"resource_id": resource_kind},
lambda ns: ns.update(
{
"resource_id": resource_kind,
"from_spec": classmethod(_from_spec),
"from_system": classmethod(_from_system),
"validate": classmethod(_validate),
"__repr_keys__": property(lambda _: {"resource_id"}),
"__doc__": inspect.cleandoc(docstring),
"__module__": "openllm._strategies",
}
),
)
return types.new_class(
name, (DynResource, ReprMixin), {"resource_id": resource_kind}, lambda ns: ns.update({"resource_id": resource_kind, "from_spec": classmethod(_from_spec), "from_system": classmethod(_from_system), "validate": classmethod(_validate), "__repr_keys__": property(lambda _: {"resource_id"}), "__doc__": inspect.cleandoc(docstring), "__module__": "openllm._strategies",}),
)
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal :facepalm:
_TPU_RESOURCE: t.Literal["cloud-tpus.google.com/v2"] = "cloud-tpus.google.com/v2"
@@ -259,159 +256,146 @@ _AMD_GPU_RESOURCE: t.Literal["amd.com/gpu"] = "amd.com/gpu"
_NVIDIA_GPU_RESOURCE: t.Literal["nvidia.com/gpu"] = "nvidia.com/gpu"
_CPU_RESOURCE: t.Literal["cpu"] = "cpu"
NvidiaGpuResource = _make_resource_class(
"NvidiaGpuResource",
_NVIDIA_GPU_RESOURCE,
"""NVIDIA GPU resource.
NvidiaGpuResource = _make_resource_class("NvidiaGpuResource", _NVIDIA_GPU_RESOURCE, """NVIDIA GPU resource.
This is a modified version of internal's BentoML's NvidiaGpuResource
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",
)
AmdGpuResource = _make_resource_class(
"AmdGpuResource",
_AMD_GPU_RESOURCE,
"""AMD GPU resource.
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",)
AmdGpuResource = _make_resource_class("AmdGpuResource", _AMD_GPU_RESOURCE, """AMD GPU resource.
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",
)
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",)
LiteralResourceSpec = t.Literal["cloud-tpus.google.com/v2", "amd.com/gpu", "nvidia.com/gpu", "cpu"]
# convenient mapping
def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec:
if name == "tpu": return _TPU_RESOURCE
elif name == "amd": return _AMD_GPU_RESOURCE
elif name == "nvidia": return _NVIDIA_GPU_RESOURCE
elif name == "cpu": return _CPU_RESOURCE
else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']")
def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec:
if name == "tpu": return _TPU_RESOURCE
elif name == "amd": return _AMD_GPU_RESOURCE
elif name == "nvidia": return _NVIDIA_GPU_RESOURCE
elif name == "cpu": return _CPU_RESOURCE
else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']")
@functools.lru_cache
def available_resource_spec() -> tuple[LiteralResourceSpec, ...]:
"""This is a utility function helps to determine the available resources from given running system.
"""This is a utility function helps to determine the available resources from given running system.
It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs.
TODO: Supports TPUs
"""
available = ()
if len(AmdGpuResource.from_system()) > 0: available += (_AMD_GPU_RESOURCE,)
if len(NvidiaGpuResource.from_system()) > 0: available += (_NVIDIA_GPU_RESOURCE,)
available += (_CPU_RESOURCE,)
return t.cast(t.Tuple[LiteralResourceSpec, ...], available)
It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs.
TODO: Supports TPUs
"""
available = ()
if len(AmdGpuResource.from_system()) > 0: available += (_AMD_GPU_RESOURCE,)
if len(NvidiaGpuResource.from_system()) > 0: available += (_NVIDIA_GPU_RESOURCE,)
available += (_CPU_RESOURCE,)
return t.cast(t.Tuple[LiteralResourceSpec, ...], available)
class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU.
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
for ROCm's support for CUDA_VISIBLE_DEVICES.
It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU.
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
for ROCm's support for CUDA_VISIBLE_DEVICES.
TODO: Support CloudTPUResource
TODO: Support CloudTPUResource
"""
@classmethod
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float) -> int:
if resource_request is None: resource_request = system_resources()
def _get_gpu_count(typ: list[str] | None, kind: str) -> int | None:
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: return math.ceil(len(typ) * workers_per_resource)
# use NVIDIA
kind = "nvidia.com/gpu"
count = _get_gpu_count(get_resource(resource_request, kind), kind)
if count: return count
# use AMD
kind = "amd.com/gpu"
count = _get_gpu_count(get_resource(resource_request, kind, validate=False), kind)
if count: return count
# use CPU
cpus = get_resource(resource_request, "cpu")
if cpus is not None and cpus > 0:
if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class)
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.")
return int(workers_per_resource)
return math.ceil(cpus) * workers_per_resource
# this should not be reached by user since we always read system resource as default
raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.")
@classmethod
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
"""Get worker env for this given worker_index.
Args:
runnable_class: The runnable class to be run.
resource_request: The resource request of the runnable.
workers_per_resource: # of workers per resource.
worker_index: The index of the worker, start from 0.
"""
@classmethod
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float) -> int:
if resource_request is None: resource_request = system_resources()
def _get_gpu_count(typ: list[str] | None, kind: str) -> int | None:
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: return math.ceil(len(typ) * workers_per_resource)
# use NVIDIA
kind = "nvidia.com/gpu"
count = _get_gpu_count(get_resource(resource_request, kind), kind)
if count: return count
# use AMD
kind = "amd.com/gpu"
count = _get_gpu_count(get_resource(resource_request, kind, validate=False), kind)
if count: return count
# use CPU
cpus = get_resource(resource_request, "cpu")
if cpus is not None and cpus > 0:
if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class)
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.")
return int(workers_per_resource)
return math.ceil(cpus) * workers_per_resource
# this should not be reached by user since we always read system resource as default
raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.")
@classmethod
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
"""Get worker env for this given worker_index.
Args:
runnable_class: The runnable class to be run.
resource_request: The resource request of the runnable.
workers_per_resource: # of workers per resource.
worker_index: The index of the worker, start from 0.
"""
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
disabled = cuda_env in ("", "-1")
environ: dict[str, t.Any] = {}
if resource_request is None: resource_request = system_resources()
# use NVIDIA
kind = "nvidia.com/gpu"
typ = get_resource(resource_request, kind)
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
if disabled:
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
return environ
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
logger.debug("Environ for worker %s: %s", worker_index, environ)
return environ
# use AMD
kind = "amd.com/gpu"
typ = get_resource(resource_request, kind, validate=False)
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
if disabled:
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
return environ
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
logger.debug("Environ for worker %s: %s", worker_index, environ)
return environ
# use CPU
cpus = get_resource(resource_request, "cpu")
if cpus is not None and cpus > 0:
environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
thread_count = math.ceil(cpus)
for thread_env in THREAD_ENVS: environ[thread_env] = os.getenv(thread_env, str(thread_count))
logger.debug("Environ for worker %s: %s", worker_index, environ)
return environ
for thread_env in THREAD_ENVS: environ[thread_env] = os.getenv(thread_env, "1")
return environ
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
disabled = cuda_env in ("", "-1")
environ: dict[str, t.Any] = {}
if resource_request is None: resource_request = system_resources()
# use NVIDIA
kind = "nvidia.com/gpu"
typ = get_resource(resource_request, kind)
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
if disabled:
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
return environ
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
logger.debug("Environ for worker %s: %s", worker_index, environ)
return environ
# use AMD
kind = "amd.com/gpu"
typ = get_resource(resource_request, kind, validate=False)
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
if disabled:
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
return environ
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
logger.debug("Environ for worker %s: %s", worker_index, environ)
return environ
# use CPU
cpus = get_resource(resource_request, "cpu")
if cpus is not None and cpus > 0:
environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
thread_count = math.ceil(cpus)
for thread_env in THREAD_ENVS:
environ[thread_env] = os.getenv(thread_env, str(thread_count))
logger.debug("Environ for worker %s: %s", worker_index, environ)
return environ
for thread_env in THREAD_ENVS:
environ[thread_env] = os.getenv(thread_env, "1")
return environ
return environ
@staticmethod
def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str:
# Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.
if isinstance(workers_per_resource, float):
# NOTE: We hit this branch when workers_per_resource is set to
# float, for example 0.5 or 0.25
if workers_per_resource > 1:
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
# We are round the assigned resource here. This means if workers_per_resource=.4
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
assigned_resource_per_worker = round(1 / workers_per_resource)
if len(gpus) < assigned_resource_per_worker:
logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker)
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
assigned_gpu = gpus[assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1) ]
dev = ",".join(assigned_gpu)
else:
idx = worker_index // workers_per_resource
if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}")
dev = str(gpus[idx])
return dev
@staticmethod
def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str:
# Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.
if isinstance(workers_per_resource, float):
# NOTE: We hit this branch when workers_per_resource is set to
# float, for example 0.5 or 0.25
if workers_per_resource > 1:
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
# We are round the assigned resource here. This means if workers_per_resource=.4
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
assigned_resource_per_worker = round(1 / workers_per_resource)
if len(gpus) < assigned_resource_per_worker:
logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker)
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index+1)]
dev = ",".join(assigned_gpu)
else:
idx = worker_index // workers_per_resource
if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}")
dev = str(gpus[idx])
return dev

View File

@@ -19,10 +19,8 @@ It will raises a RuntimeError if this is imported eagerly.
from __future__ import annotations
import typing as t
if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime")
import attr
import bentoml
@@ -31,17 +29,15 @@ from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
from ._configuration import AdapterType
from ._configuration import LiteralRuntime as LiteralRuntime
if t.TYPE_CHECKING:
import peft
import openllm
from openllm._llm import M as _M
from openllm._llm import T as _T
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
from bentoml._internal.runner.strategy import Strategy
import peft
import openllm
from openllm._llm import M as _M
from openllm._llm import T as _T
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
from bentoml._internal.runner.strategy import Strategy
AnyCallable = t.Callable[..., t.Any]
DictStrAny = dict[str, t.Any]
@@ -54,72 +50,72 @@ T = t.TypeVar("T")
Ts = t.TypeVarTuple("Ts")
At = t.TypeVar("At", bound=attr.AttrsInstance)
class PeftAdapterOutput(t.TypedDict):
success: bool
result: dict[str, peft.PeftConfig]
error_msg: str
success: bool
result: dict[str, peft.PeftConfig]
error_msg: str
class LLMEmbeddings(t.TypedDict):
embeddings: t.List[t.List[float]]
num_tokens: int
embeddings: t.List[t.List[float]]
num_tokens: int
class AdaptersTuple(TupleAny):
adapter_id: str
name: str | None
config: DictStrAny
adapter_id: str
name: str | None
config: DictStrAny
AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]]
class LLMRunnable(bentoml.Runnable, t.Generic[_M, _T]):
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
__call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
__call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
class LLMRunner(bentoml.Runner, t.Generic[_M, _T]):
__doc__: str
__module__: str
llm_type: str
identifying_params: dict[str, t.Any]
llm: openllm.LLM[_M, _T]
config: openllm.LLMConfig
implementation: LiteralRuntime
supports_embeddings: bool
supports_hf_agent: bool
has_adapters: bool
embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
def __init__(
self,
runnable_class: type[LLMRunnable[_M, _T]],
*,
runnable_init_params: dict[str, t.Any] | None = ...,
name: str | None = ...,
scheduling_strategy: type[Strategy] = ...,
models: list[bentoml.Model] | None = ...,
max_batch_size: int | None = ...,
max_latency_ms: int | None = ...,
method_configs: dict[str, dict[str, int]] | None = ...,
embedded: bool = False,
) -> None: ...
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
def download_model(self) -> bentoml.Model: ...
@property
def peft_adapters(self) -> PeftAdapterOutput: ...
@property
def __repr_keys__(self) -> set[str]: ...
__doc__: str
__module__: str
llm_type: str
identifying_params: dict[str, t.Any]
llm: openllm.LLM[_M, _T]
config: openllm.LLMConfig
implementation: LiteralRuntime
supports_embeddings: bool
supports_hf_agent: bool
has_adapters: bool
embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
def __init__(
self, runnable_class: type[LLMRunnable[_M, _T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,
) -> None:
...
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any:
...
def embed(self, prompt: str | list[str]) -> LLMEmbeddings:
...
def run(self, prompt: str, **attrs: t.Any) -> t.Any:
...
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any:
...
def download_model(self) -> bentoml.Model:
...
@property
def peft_adapters(self) -> PeftAdapterOutput:
...
@property
def __repr_keys__(self) -> set[str]:
...

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build-related utilities. Some of these utilities are mainly used for 'openllm.build'.
These utilities will stay internal, and its API can be changed or updated without backward-compatibility.
@@ -23,20 +22,18 @@ import typing as t
from . import oci as oci
from ..utils import LazyModule
_import_structure: dict[str, list[str]] = {
"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"],
"oci": oci.__all__,
}
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__,}
if t.TYPE_CHECKING:
from . import _package as _package
from ._package import build_editable as build_editable
from ._package import construct_docker_options as construct_docker_options
from ._package import construct_python_options as construct_python_options
from ._package import create_bento as create_bento
from .oci import CONTAINER_NAMES as CONTAINER_NAMES
from .oci import build_container as build_container
from .oci import get_base_container_name as get_base_container_name
from .oci import get_base_container_tag as get_base_container_tag
from .oci import supported_registries as supported_registries
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from . import _package as _package
from ._package import build_editable as build_editable
from ._package import construct_docker_options as construct_docker_options
from ._package import construct_python_options as construct_python_options
from ._package import create_bento as create_bento
from .oci import CONTAINER_NAMES as CONTAINER_NAMES
from .oci import build_container as build_container
from .oci import get_base_container_name as get_base_container_name
from .oci import get_base_container_tag as get_base_container_tag
from .oci import supported_registries as supported_registries
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -45,200 +45,149 @@ from ..utils import is_torch_available
from ..utils import pkg
if t.TYPE_CHECKING:
from fs.base import FS
from fs.base import FS
import openllm
from bentoml._internal.bento import BentoStore
from bentoml._internal.models.model import ModelStore
import openllm
from bentoml._internal.bento import BentoStore
from bentoml._internal.models.model import ModelStore
from .oci import LiteralContainerRegistry
from .oci import LiteralContainerVersionStrategy
from .oci import LiteralContainerRegistry
from .oci import LiteralContainerVersionStrategy
logger = logging.getLogger(__name__)
OPENLLM_DEV_BUILD = "OPENLLM_DEV_BUILD"
def build_editable(path: str) -> str | None:
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None
# We need to build the package in editable mode, so that we can import it
from build import ProjectBuilder
from build.env import IsolatedEnvBuilder
module_location = pkg.source_locations("openllm")
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
if os.path.isfile(pyproject_path.__fspath__()):
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
with IsolatedEnvBuilder() as env:
builder = ProjectBuilder(pyproject_path.parent)
builder.python_executable = env.executable
builder.scripts_dir = env.scripts_dir
env.install(builder.build_system_requires)
return builder.build("wheel", path, config_settings={"--global-option": "--quiet"})
raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.")
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None
# We need to build the package in editable mode, so that we can import it
from build import ProjectBuilder
from build.env import IsolatedEnvBuilder
module_location = pkg.source_locations("openllm")
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
if os.path.isfile(pyproject_path.__fspath__()):
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
with IsolatedEnvBuilder() as env:
builder = ProjectBuilder(pyproject_path.parent)
builder.python_executable = env.executable
builder.scripts_dir = env.scripts_dir
env.install(builder.build_system_requires)
return builder.build("wheel", path, config_settings={"--global-option": "--quiet"})
raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.")
def construct_python_options(
llm: openllm.LLM[t.Any, t.Any],
llm_fs: FS,
extra_dependencies: tuple[str, ...] | None = None,
adapter_map: dict[str, str | None] | None = None,
) -> PythonOptions:
packages = ["openllm", "scipy"] # apparently bnb misses this one
if adapter_map is not None: packages += ["openllm[fine-tune]"]
# NOTE: add openllm to the default dependencies
# if users has openllm custom built wheels, it will still respect
# that since bentoml will always install dependencies from requirements.txt
# first, then proceed to install everything inside the wheels/ folder.
if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies]
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions:
packages = ["openllm", "scipy"] # apparently bnb misses this one
if adapter_map is not None: packages += ["openllm[fine-tune]"]
# NOTE: add openllm to the default dependencies
# if users has openllm custom built wheels, it will still respect
# that since bentoml will always install dependencies from requirements.txt
# first, then proceed to install everything inside the wheels/ folder.
if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies]
req = llm.config["requirements"]
if req is not None: packages.extend(req)
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
req = llm.config["requirements"]
if req is not None: packages.extend(req)
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
env = llm.config["env"]
framework_envvar = env["framework_value"]
if framework_envvar == "flax":
if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")])
elif framework_envvar == "tf":
if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
)
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for candidate in candidates:
try:
pkgver = importlib.metadata.version(candidate)
if pkgver == candidate: packages.extend(["tensorflow"])
else:
_tf_version = importlib.metadata.version(candidate)
packages.extend([f"tensorflow>={_tf_version}"])
break
except importlib.metadata.PackageNotFoundError: pass
else:
if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
env = llm.config["env"]
framework_envvar = env["framework_value"]
if framework_envvar == "flax":
if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")])
elif framework_envvar == "tf":
if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos",)
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for candidate in candidates:
try:
pkgver = importlib.metadata.version(candidate)
if pkgver == candidate: packages.extend(["tensorflow"])
else:
_tf_version = importlib.metadata.version(candidate)
packages.extend([f"tensorflow>={_tf_version}"])
break
except importlib.metadata.PackageNotFoundError:
pass
else:
if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
wheels: list[str] = []
built_wheels = build_editable(llm_fs.getsyspath("/"))
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
wheels: list[str] = []
built_wheels = build_editable(llm_fs.getsyspath("/"))
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
def construct_docker_options(
llm: openllm.LLM[t.Any, t.Any],
_: FS,
workers_per_resource: int | float,
quantize: t.LiteralString | None,
bettertransformer: bool | None,
adapter_map: dict[str, str | None] | None,
dockerfile_template: str | None,
runtime: t.Literal["ggml", "transformers"],
serialisation_format: t.Literal["safetensors", "legacy"],
container_registry: LiteralContainerRegistry,
llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry,
container_version_strategy: LiteralContainerVersionStrategy,
) -> DockerOptions:
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = [
"api_server.traffic.timeout=36000", # NOTE: Currently we hardcode this value
f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}',
f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
]
_bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts)
env: EnvVarMixin = llm.config["env"]
env_dict = {
env.framework: env.framework_value,
env.config: f"'{llm.config.model_dump_json().decode()}'",
"OPENLLM_MODEL": llm.config["model_name"],
"OPENLLM_SERIALIZATION": serialisation_format,
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
"OPENLLM_FAST": str(True),
"BENTOML_DEBUG": str(True),
"BENTOML_QUIET": str(False),
"OPENLLMDEVDEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var
}
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = [
"api_server.traffic.timeout=36000", # NOTE: Currently we hardcode this value
f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
]
_bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts)
env: EnvVarMixin = llm.config["env"]
env_dict = {
env.framework: env.framework_value, env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "OPENLLM_FAST": str(True), "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "OPENLLMDEVDEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var
}
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
# We need to handle None separately here, as env from subprocess doesn't accept None value.
_env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
# We need to handle None separately here, as env from subprocess doesn't accept None value.
_env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value)
if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value
env_dict[_env.runtime] = _env.runtime_value
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}",env=env_dict, dockerfile_template=dockerfile_template)
if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value)
if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value
env_dict[_env.runtime] = _env.runtime_value
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template)
@inject
def create_bento(
bento_tag: bentoml.Tag,
llm_fs: FS,
llm: openllm.LLM[t.Any, t.Any],
workers_per_resource: str | int | float,
quantize: t.LiteralString | None,
bettertransformer: bool | None,
dockerfile_template: str | None,
adapter_map: dict[str, str | None] | None = None,
extra_dependencies: tuple[str, ...] | None = None,
runtime: t.Literal["ggml", "transformers"] = "transformers",
serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors",
container_registry: LiteralContainerRegistry = "ecr",
container_version_strategy: LiteralContainerVersionStrategy = "release",
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",
serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release", _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store],
) -> bentoml.Bento:
framework_envvar = llm.config["env"]["framework_value"]
labels = dict(llm.identifying_params)
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
if adapter_map: labels.update(adapter_map)
if isinstance(workers_per_resource, str):
if workers_per_resource == "round_robin": workers_per_resource = 1.0
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if device_count() == 0 else float(1 / device_count())
else:
try: workers_per_resource = float(workers_per_resource)
except ValueError: raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource)
framework_envvar = llm.config["env"]["framework_value"]
labels = dict(llm.identifying_params)
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
if adapter_map: labels.update(adapter_map)
if isinstance(workers_per_resource, str):
if workers_per_resource == "round_robin": workers_per_resource = 1.0
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if device_count() == 0 else float(1 / device_count())
else:
try:
workers_per_resource = float(workers_per_resource)
except ValueError:
raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
elif isinstance(workers_per_resource, int):
workers_per_resource = float(workers_per_resource)
logger.info("Building Bento for '%s'", llm.config["start_name"])
# add service.py definition to this temporary folder
codegen.write_service(llm, adapter_map, llm_fs)
logger.info("Building Bento for '%s'", llm.config["start_name"])
# add service.py definition to this temporary folder
codegen.write_service(llm, adapter_map, llm_fs)
llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name})
build_config = BentoBuildConfig(
service=f"{llm.config['service_name']}:svc",
name=bento_tag.name,
labels=labels,
description=f"OpenLLM service for {llm.config['start_name']}",
include=list(llm_fs.walk.files()),
exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"],
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy),
models=[llm_spec],
)
llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name})
build_config = BentoBuildConfig(
service=f"{llm.config['service_name']}:svc", name=bento_tag.name, labels=labels, description=f"OpenLLM service for {llm.config['start_name']}", include=list(llm_fs.walk.files()), exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"], python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy), models=[llm_spec],
)
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/"))
# NOTE: the model_id_path here are only used for setting this environment variable within the container
# built with for BentoLLM.
service_fs_path = fs.path.join("src", llm.config["service_name"])
service_path = bento._fs.getsyspath(service_fs_path)
with open(service_path, "r") as f: service_contents = f.readlines()
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/"))
# NOTE: the model_id_path here are only used for setting this environment variable within the container
# built with for BentoLLM.
service_fs_path = fs.path.join("src", llm.config["service_name"])
service_path = bento._fs.getsyspath(service_fs_path)
with open(service_path, "r") as f:
service_contents = f.readlines()
for it in service_contents:
if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
for it in service_contents:
if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
script = "".join(service_contents)
if DEBUG: logger.info("Generated script:\n%s", script)
script = "".join(service_contents)
if DEBUG: logger.info("Generated script:\n%s", script)
bento._fs.writetext(service_fs_path, script)
if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
# backward arguments. `model_store` is added recently
return bento.save(bento_store=_bento_store)
bento._fs.writetext(service_fs_path, script)
if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
# backward arguments. `model_store` is added recently
return bento.save(bento_store=_bento_store)

View File

@@ -42,8 +42,7 @@ LiteralContainerVersionStrategy = t.Literal["release", "nightly", "latest"]
# but in the future, we can infer based on git repo and everything to make it more options for users
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
# NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it.
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm",
"ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"}
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm", "ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"}
_URI = "https://github.com/bentoml/openllm.git"
@@ -51,56 +50,63 @@ _module_location = pkg.source_locations("openllm")
@functools.lru_cache
@apply(str.lower)
def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg]
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
return _CONTAINER_REGISTRY[reg]
@functools.lru_cache(maxsize=1)
def _git() -> git.cmd.Git: return git.cmd.Git(_URI)
def _git() -> git.cmd.Git:
return git.cmd.Git(_URI)
@functools.lru_cache
def _nightly_ref() -> tuple[str, str]: return _git().ls_remote(_URI, "main", heads=True).split()
def _nightly_ref() -> tuple[str, str]:
return _git().ls_remote(_URI, "main", heads=True).split()
@functools.lru_cache
def _stable_ref() -> tuple[str, str]: return max([item.split() for item in _git().ls_remote(_URI, refs=True, tags=True).split("\n")],
key = lambda tag: tuple(int(k) for k in tag[-1].replace("refs/tags/v", "").split(".")))
def _stable_ref() -> tuple[str, str]:
return max([item.split() for item in _git().ls_remote(_URI, refs=True, tags=True).split("\n")], key=lambda tag: tuple(int(k) for k in tag[-1].replace("refs/tags/v", "").split(".")))
def get_base_container_tag(strategy: LiteralContainerVersionStrategy) -> str:
if strategy == "release": return _stable_ref()[-1].replace("refs/tags/v", "") # for stable, we can also use latest, but discouraged
elif strategy == "latest": return "latest"
elif strategy == "nightly": return f"sha-{_nightly_ref()[0][:7]}" # we prefixed with sha-<git_rev_short> (giv_rev[:7])
else: raise ValueError(f"Unknown strategy '{strategy}'. Valid strategies are 'release', 'nightly', and 'latest'")
if strategy == "release": return _stable_ref()[-1].replace("refs/tags/v", "") # for stable, we can also use latest, but discouraged
elif strategy == "latest": return "latest"
elif strategy == "nightly": return f"sha-{_nightly_ref()[0][:7]}" # we prefixed with sha-<git_rev_short> (giv_rev[:7])
else: raise ValueError(f"Unknown strategy '{strategy}'. Valid strategies are 'release', 'nightly', and 'latest'")
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
"""
try:
if not _BUILDER.health(): raise Error
except (Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
if device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("Make sure to have NVIDIA Container Toolkit setup correctly to compile CUDA kernel in container. See https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html for more information.")
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
tags: dict[str | LiteralContainerRegistry, str]
if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
else:
registries = [registries] if isinstance(registries, str) else list(registries)
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
try:
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()),
push=push, progress="plain" if get_debug_mode() else "auto", quiet=machine)
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
except Exception as err: raise OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
return tags
Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
"""
try:
if not _BUILDER.health(): raise Error
except (Error, subprocess.CalledProcessError):
raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
if device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.")
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
if t.TYPE_CHECKING: tags: dict[str | LiteralContainerRegistry, str]
if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
else:
registries = [registries] if isinstance(registries, str) else list(registries)
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
try:
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if get_debug_mode() else "auto", quiet=machine)
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
except Exception as err:
raise OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
return tags
if t.TYPE_CHECKING:
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
supported_registries: list[str]
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
supported_registries: list[str]
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries"]
def __dir__() -> list[str]: return sorted(__all__)
def __dir__() -> list[str]:
return sorted(__all__)
def __getattr__(name: str) -> t.Any:
if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))()
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
elif name in __all__: return importlib.import_module("." + name, __name__)
else: raise AttributeError(f"{name} does not exists under {__name__}")
if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))()
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
elif name in __all__: return importlib.import_module("." + name, __name__)
else: raise AttributeError(f"{name} does not exists under {__name__}")

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenLLM CLI.
For more information see ``openllm -h``.

View File

@@ -47,71 +47,64 @@ from ..utils import is_peft_available
from ..utils import resolve_user_filepath
if t.TYPE_CHECKING:
import subprocess
import subprocess
from .._configuration import LLMConfig
from .._types import DictStrAny
from .._types import P
TupleStr = tuple[str, ...]
else: TupleStr = tuple
from .._configuration import LLMConfig
from .._types import DictStrAny
from .._types import P
TupleStr = tuple[str, ...]
else:
TupleStr = tuple
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
_AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command])
def parse_config_options(
config: LLMConfig,
server_timeout: int,
workers_per_resource: float,
device: tuple[str, ...] | None,
environ: DictStrAny,
) -> DictStrAny:
# TODO: Support amd.com/gpu on k8s
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
if device:
if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
_bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts)
environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env
return environ
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny:
# TODO: Support amd.com/gpu on k8s
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
if device:
if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
_bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts)
environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env
return environ
_adapter_mapping_key = "adapter_map"
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None:
if not value: return None
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
for v in value:
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
# try to resolve the full path if users pass in relative,
# currently only support one level of resolve path with current directory
try: adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError: pass
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
return None
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None:
if not value: return None
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
for v in value:
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
# try to resolve the full path if users pass in relative,
# currently only support one level of resolve path with current directory
try:
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
return None
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
"""Generate a 'click.Command' for any given LLM.
"""Generate a 'click.Command' for any given LLM.
Args:
group: the target ``click.Group`` to save this LLM cli under
model: The name of the model or the ``bentoml.Bento`` instance.
Args:
group: the target ``click.Group`` to save this LLM cli under
model: The name of the model or the ``bentoml.Bento`` instance.
Returns:
The click.Command for starting the model server
Returns:
The click.Command for starting the model server
Note that the internal commands will return the llm_config and a boolean determine
whether the server is run with GPU or not.
"""
llm_config = AutoConfig.for_model(model)
Note that the internal commands will return the llm_config and a boolean determine
whether the server is run with GPU or not.
"""
llm_config = AutoConfig.for_model(model)
command_attrs: DictStrAny = dict(
name=llm_config["model_name"],
context_settings=_context_settings or termui.CONTEXT_SETTINGS,
short_help=f"Start a LLMServer for '{model}'",
aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None,
help=f"""\
command_attrs: DictStrAny = dict(
name=llm_config["model_name"], context_settings=_context_settings or termui.CONTEXT_SETTINGS, short_help=f"Start a LLMServer for '{model}'", aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None, help=f"""\
{llm_config['env'].start_docstring}
\b
@@ -130,142 +123,125 @@ Available official model_id(s): [default: {llm_config['default_id']}]
\b
{orjson.dumps(llm_config['model_ids'], option=orjson.OPT_INDENT_2).decode()}
""",
)
)
if llm_config["requires_gpu"] and device_count() < 1:
# NOTE: The model requires GPU, therefore we will return a dummy command
command_attrs.update({"short_help": "(Disabled because there is no GPU available)", "help": f"""{model} is currently not available to run on your local machine because it requires GPU for inference."""})
return noop_command(group, llm_config, _serve_grpc, **command_attrs)
if llm_config["requires_gpu"] and device_count() < 1:
# NOTE: The model requires GPU, therefore we will return a dummy command
command_attrs.update({"short_help": "(Disabled because there is no GPU available)", "help": f"""{model} is currently not available to run on your local machine because it requires GPU for inference."""})
return noop_command(group, llm_config, _serve_grpc, **command_attrs)
@group.command(**command_attrs)
@start_decorator(llm_config, serve_grpc=_serve_grpc)
@click.pass_context
def start_cmd(
ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool,
serialisation_format: t.Literal["safetensors", "legacy"], adapter_id: str | None, return_process: bool, **attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
fast = str(fast).upper() in ENV_VARS_TRUE_VALUES
if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES:
termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow")
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
config, server_attrs = llm_config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config["timeout"])
server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout})
if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1"
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop("development")
server_attrs.setdefault("production", not development)
wpr = first_not_none(workers_per_resource, default=config["workers_per_resource"])
@group.command(**command_attrs)
@start_decorator(llm_config, serve_grpc=_serve_grpc)
@click.pass_context
def start_cmd(
ctx: click.Context, /,
server_timeout: int,
model_id: str | None,
model_version: str | None,
workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString,
device: tuple[str, ...],
quantize: t.Literal["int8", "int4", "gptq"] | None,
bettertransformer: bool | None,
runtime: t.Literal["ggml", "transformers"],
fast: bool,
serialisation_format: t.Literal["safetensors", "legacy"],
adapter_id: str | None,
return_process: bool,
**attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
fast = str(fast).upper() in ENV_VARS_TRUE_VALUES
if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES:
termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow")
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
config, server_attrs = llm_config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config["timeout"])
server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout})
if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1"
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop("development")
server_attrs.setdefault("production", not development)
wpr = first_not_none(workers_per_resource, default=config["workers_per_resource"])
if isinstance(wpr, str):
if wpr == "round_robin": wpr = 1.0
elif wpr == "conserved":
if device and device_count() == 0:
termui.echo("--device will have no effect as there is no GPUs available", fg="yellow")
wpr = 1.0
else:
available_gpu = len(device) if device else device_count()
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
else: wpr = float(wpr)
elif isinstance(wpr, int): wpr = float(wpr)
# Create a new model env to work with the envvar during CLI invocation
env = EnvVarMixin(config["model_name"], config.default_implementation(), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr))
# NOTE: This is to set current configuration
start_env = os.environ.copy()
start_env = parse_config_options(config, server_timeout, wpr, device, start_env)
if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow")
start_env.update(
{
"OPENLLM_MODEL": model,
"BENTOML_DEBUG": str(get_debug_mode()),
"BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()),
"OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(),
"OPENLLM_SERIALIZATION": serialisation_format,
env.runtime: env.runtime_value,
env.framework: env.framework_value,
}
)
if env.model_id_value: start_env[env.model_id] = str(env.model_id_value)
# NOTE: quantize and bettertransformer value is already assigned within env
if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value)
if quantize is not None: start_env[env.quantize] = str(env.quantize_value)
llm = infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format)
start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id})
server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs)
analytics.track_start_init(llm.config)
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
cmd_name = f"openllm build {model_name}"
if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]])
if not get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue")
if return_process:
server.start(env=start_env, text=True)
if server.process is None: raise click.ClickException("Failed to start the server.")
return server.process
if isinstance(wpr, str):
if wpr == "round_robin": wpr = 1.0
elif wpr == "conserved":
if device and device_count() == 0:
termui.echo("--device will have no effect as there is no GPUs available", fg="yellow")
wpr = 1.0
else:
try: server.start(env=start_env, text=True, blocking=True)
except KeyboardInterrupt: next_step(model, adapter_map)
except Exception as err: termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red")
else: next_step(model, adapter_map)
available_gpu = len(device) if device else device_count()
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
else:
wpr = float(wpr)
elif isinstance(wpr, int):
wpr = float(wpr)
# NOTE: Return the configuration for telemetry purposes.
return config
# Create a new model env to work with the envvar during CLI invocation
env = EnvVarMixin(config["model_name"], config.default_implementation(), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr))
return start_cmd
# NOTE: This is to set current configuration
start_env = os.environ.copy()
start_env = parse_config_options(config, server_timeout, wpr, device, start_env)
if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow")
start_env.update({"OPENLLM_MODEL": model, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), "OPENLLM_SERIALIZATION": serialisation_format, env.runtime: env.runtime_value, env.framework: env.framework_value,})
if env.model_id_value: start_env[env.model_id] = str(env.model_id_value)
# NOTE: quantize and bettertransformer value is already assigned within env
if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value)
if quantize is not None: start_env[env.quantize] = str(env.quantize_value)
llm = infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format)
start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id})
server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs)
analytics.track_start_init(llm.config)
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
cmd_name = f"openllm build {model_name}"
if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]])
if not get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue")
if return_process:
server.start(env=start_env, text=True)
if server.process is None: raise click.ClickException("Failed to start the server.")
return server.process
else:
try:
server.start(env=start_env, text=True, blocking=True)
except KeyboardInterrupt:
next_step(model, adapter_map)
except Exception as err:
termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red")
else:
next_step(model, adapter_map)
# NOTE: Return the configuration for telemetry purposes.
return config
return start_cmd
def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, **command_attrs: t.Any) -> click.Command:
context_settings = command_attrs.pop("context_settings", {})
context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True})
command_attrs["context_settings"] = context_settings
# NOTE: The model requires GPU, therefore we will return a dummy command
@group.command(**command_attrs)
def noop(**_: t.Any) -> LLMConfig:
termui.echo("No GPU available, therefore this command is disabled", fg="red")
analytics.track_start_init(llm_config)
return llm_config
return noop
context_settings = command_attrs.pop("context_settings", {})
context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True})
command_attrs["context_settings"] = context_settings
# NOTE: The model requires GPU, therefore we will return a dummy command
@group.command(**command_attrs)
def noop(**_: t.Any) -> LLMConfig:
termui.echo("No GPU available, therefore this command is disabled", fg="red")
analytics.track_start_init(llm_config)
return llm_config
return noop
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
if adapter_map and not is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config.env['framework']}=\"pt\"' to run with quantization.")
requirements = llm_config["requirements"]
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
if adapter_map and not is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config.env['framework']}=\"pt\"' to run with quantization.")
requirements = llm_config["requirements"]
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
return lambda fn: compose(*[
llm_config.to_click_options,
_http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
model_version_option(factory=cog.optgroup),
cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"),
workers_per_resource_option(factory=cog.optgroup),
fast_option(factory=cog.optgroup),
cog.optgroup.group(
"LLM Optimization Options",
help="""Optimization related options.
return lambda fn: compose(
*[
llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
model_version_option(factory=cog.optgroup),
cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"),
workers_per_resource_option(factory=cog.optgroup),
fast_option(factory=cog.optgroup),
cog.optgroup.group(
"LLM Optimization Options", help="""Optimization related options.
OpenLLM supports running model with [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/),
k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
@@ -275,15 +251,14 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
""",
),
cog.optgroup.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help=f"Assign GPU devices (if available) for {llm_config['model_name']}.", show_envvar=True),
cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."),
quantize_option(factory=cog.optgroup, model_env=llm_config["env"]),
bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]),
serialisation_option(factory=cog.optgroup),
cog.optgroup.group(
"Fine-tuning related options",
help="""\
),
cog.optgroup.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help=f"Assign GPU devices (if available) for {llm_config['model_name']}.", show_envvar=True),
cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."),
quantize_option(factory=cog.optgroup, model_env=llm_config["env"]),
bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]),
serialisation_option(factory=cog.optgroup),
cog.optgroup.group(
"Fine-tuning related options", help="""\
Note that the argument `--adapter-id` can accept the following format:
- `--adapter-id /path/to/adapter` (local adapter)
@@ -298,18 +273,19 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
```
""",
),
cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"),
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
])(fn)
),
cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"),
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
]
)(fn)
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> TupleStr | None:
if value is None: return value
if not LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
el: TupleStr = tuple(i for k in value for i in k)
# NOTE: --device all is a special case
if len(el) == 1 and el[0] == "all": return tuple(map(str, available_devices()))
return el
if value is None: return value
if not LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
el: TupleStr = tuple(i for k in value for i in k)
# NOTE: --device all is a special case
if len(el) == 1 and el[0] == "all": return tuple(map(str, available_devices()))
return el
# NOTE: A list of bentoml option that is not needed for parsing.
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
@@ -317,67 +293,78 @@ def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tup
_IGNORED_OPTIONS = {"working_dir", "production", "protocol_version"}
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
from bentoml_cli.cli import cli
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
from bentoml_cli.cli import cli
command = "serve" if not serve_grpc else "serve-grpc"
group = cog.optgroup.group(
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
)
command = "serve" if not serve_grpc else "serve-grpc"
group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",)
def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]:
serve_command = cli.commands[command]
# The first variable is the argument bento
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
serve_options = [p for p in serve_command.params[1 :-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
for options in reversed(serve_options):
attrs = options.to_info_dict()
# we don't need param_type_name, since it should all be options
attrs.pop("param_type_name")
# name is not a valid args
attrs.pop("name")
# type can be determine from default value
attrs.pop("type")
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
f = cog.optgroup.option(*param_decls, **attrs)(f)
return group(f)
return decorator
def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]:
serve_command = cli.commands[command]
# The first variable is the argument bento
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
serve_options = [p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
for options in reversed(serve_options):
attrs = options.to_info_dict()
# we don't need param_type_name, since it should all be options
attrs.pop("param_type_name")
# name is not a valid args
attrs.pop("name")
# type can be determine from default value
attrs.pop("type")
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
f = cog.optgroup.option(*param_decls, **attrs)(f)
return group(f)
return decorator
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
def cli_option(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
"""General ``@click.option`` with some sauce.
"""General ``@click.option`` with some sauce.
This decorator extends the default ``@click.option`` plus a factory option to use which type of option, for example: [click, click_option_group.optgroup]
"""
attrs.setdefault("help", "General option for OpenLLM CLI.")
factory = attrs.pop("factory", click)
def decorator(f: FC | None) -> FC: return t.cast(FC, factory.option(*param_decls, **attrs)(f) if f is not None else factory.option(*param_decls, **attrs))
return decorator
This decorator extends the default ``@click.option`` plus a factory option to use which type of option, for example: [click, click_option_group.optgroup]
"""
attrs.setdefault("help", "General option for OpenLLM CLI.")
factory = attrs.pop("factory", click)
def decorator(f: FC | None) -> FC:
return t.cast(FC, factory.option(*param_decls, **attrs)(f) if f is not None else factory.option(*param_decls, **attrs))
return decorator
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]:
output = ["json", "pretty", "porcelain"]
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: return [CompletionItem(it) for it in output]
return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True,
envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f)
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True,
help="""Whether to skip checking if models is already in store.
output = ["json", "pretty", "porcelain"]
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
return [CompletionItem(it) for it in output]
return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f)
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, help="""Whether to skip checking if models is already in store.
This is useful if you already downloaded or setup the model beforehand.
""", **attrs)(f)
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
def model_id_option(f: _AnyCallable | None = None, *, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None,
show_envvar=model_env is not None,
help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None,
help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
""", **attrs
)(f)
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
def model_id_option(f: _AnyCallable | None = None, *, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]:
arg = click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in CONFIG_MAPPING]), required=required)
return arg(f) if f is not None else arg
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None,
envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None,
help="""Dynamic quantization for running this LLM.
arg = click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in CONFIG_MAPPING]), required=required)
return arg(f) if f is not None else arg
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None, envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None, help="""Dynamic quantization for running this LLM.
The following quantization strategies are supported:
@@ -388,11 +375,16 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
**Note** that the model can also be served with quantized weights.
""" + ("""
**Note** that this will set the mode for serving within deployment.""" if build else "") + """
**Note** that quantization are currently only available in *PyTorch* models.""", **attrs)(f)
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--workers-per-resource", default=None, callback=workers_per_resource_callback, type=str, required=False,
help="""Number of workers per resource assigned.
""" + (
"""
**Note** that this will set the mode for serving within deployment.""" if build else ""
) + """
**Note** that quantization are currently only available in *PyTorch* models.""", **attrs
)(f)
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--workers-per-resource", default=None, callback=workers_per_resource_callback, type=str, required=False, help="""Number of workers per resource assigned.
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
for more information. By default, this is set to 1.
@@ -402,16 +394,22 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
- ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
""" + ("""\n
""" + (
"""\n
**Note**: The workers value passed into 'build' will determine how the LLM can
be provisioned in Kubernetes as well as in standalone container. This will
ensure it has the same effect with 'openllm start --workers ...'""" if build else ""), **attrs)(f)
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--bettertransformer", is_flag=True, default=None, envvar=model_env.bettertransformer if model_env is not None else None, show_envvar=model_env is not None,
help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.",
**attrs)(f)
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--serialisation", "--serialization", "serialisation_format", type=click.Choice(["safetensors", "legacy"]),
default="safetensors", show_default=True, show_envvar=True, envvar="OPENLLM_SERIALIZATION",
help="""Serialisation format for save/load LLM.
ensure it has the same effect with 'openllm start --workers ...'""" if build else ""
), **attrs
)(f)
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--bettertransformer", is_flag=True, default=None, envvar=model_env.bettertransformer if model_env is not None else None, show_envvar=model_env is not None, help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.", **attrs
)(f)
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--serialisation", "--serialization", "serialisation_format", type=click.Choice(["safetensors", "legacy"]), default="safetensors", show_default=True, show_envvar=True, envvar="OPENLLM_SERIALIZATION", help="""Serialisation format for save/load LLM.
Currently the following strategies are supported:
@@ -429,29 +427,34 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
**Note** that GGML format is working in progress.
""", **attrs
)(f)
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY",
callback=container_registry_callback,
help="""The default container registry to get the base image for building BentoLLM.
)(f)
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY", callback=container_registry_callback, help="""The default container registry to get the base image for building BentoLLM.
Currently, it supports 'ecr', 'ghcr.io', 'docker.io'
\b
**Note** that in order to build the base image, you will need a GPUs to compile custom kernel. See ``openllm ext build-base-container`` for more information.
""")(f)
"""
)(f)
_wpr_strategies = {"round_robin", "conserved"}
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
if value is None: return value
value = inflection.underscore(value)
if value in _wpr_strategies: return value
if value is None: return value
value = inflection.underscore(value)
if value in _wpr_strategies: return value
else:
try:
float(value) # type: ignore[arg-type]
except ValueError:
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
else:
try: float(value) # type: ignore[arg-type]
except ValueError: raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
else: return value
return value
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
if value is None: return value
if value not in bundle.supported_registries: raise click.BadParameter(f"Value must be one of {bundle.supported_registries}", ctx, param)
return value
if value is None: return value
if value not in bundle.supported_registries: raise click.BadParameter(f"Value must be one of {bundle.supported_registries}", ctx, param)
return value

View File

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenLLM CLI Extension.
The following directory contains all possible extensions for OpenLLM CLI

View File

@@ -24,12 +24,11 @@ from .. import termui
from .._factory import machine_option
if t.TYPE_CHECKING:
from openllm.bundle.oci import LiteralContainerRegistry
from openllm.bundle.oci import LiteralContainerVersionStrategy
from openllm.bundle.oci import LiteralContainerRegistry
from openllm.bundle.oci import LiteralContainerVersionStrategy
@click.command("build_base_container",
context_settings=termui.CONTEXT_SETTINGS,
help="""Base image builder for BentoLLM.
@click.command(
"build_base_container", context_settings=termui.CONTEXT_SETTINGS, help="""Base image builder for BentoLLM.
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
@@ -40,12 +39,13 @@ if t.TYPE_CHECKING:
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
""")
"""
)
@click.option("--registry", multiple=True, type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)), help="Target registry to create image tag on.", default=None)
@click.option("--version-strategy", type=click.Choice(["release", "latest", "nightly"]), default="nightly", help="Version strategy to use for tagging the image.")
@click.option("--push/--no-push", help="Whether to push to remote repository", is_flag=True, default=False)
@machine_option
def cli(registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
mapping = openllm.bundle.build_container(registry, version_strategy, push, machine)
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
return mapping
mapping = openllm.bundle.build_container(registry, version_strategy, push, machine)
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
return mapping

View File

@@ -28,7 +28,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
from .. import termui
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from bentoml._internal.bento import BentoStore
@click.command("dive_bentos", context_settings=termui.CONTEXT_SETTINGS)
@click.argument("bento", type=str)
@@ -36,12 +36,15 @@ if t.TYPE_CHECKING:
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try: bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle": ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
if machine: return bentomodel.path
# copy and paste this into a new shell
if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path)
else:subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path)
ctx.exit(0)
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:
ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle":
ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
if machine: return bentomodel.path
# copy and paste this into a new shell
if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path)
else: subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path)
ctx.exit(0)

View File

@@ -29,28 +29,28 @@ from .. import termui
from ...utils import bentoml_cattr
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from bentoml._internal.bento import BentoStore
@click.command("get_containerfile", context_settings=termui.CONTEXT_SETTINGS, help="Return Containerfile of any given Bento.")
@click.argument("bento", type=str)
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
try: bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
# The logic below are similar to bentoml._internal.container.construct_containerfile
with open(bentomodel.path_of("bento.yaml"), "r") as f:
options = BentoInfo.from_yaml_file(f)
# NOTE: dockerfile_template is already included in the
# Dockerfile inside bento, and it is not relevant to
# construct_containerfile. Hence it is safe to set it to None here.
# See https://github.com/bentoml/BentoML/issues/3399.
docker_attrs = bentoml_cattr.unstructure(options.docker)
# NOTE: if users specify a dockerfile_template, we will
# save it to /env/docker/Dockerfile.template. This is necessary
# for the reconstruction of the Dockerfile.
if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template"
termui.echo(generate_containerfile(
docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True,
), fg="white")
return bentomodel.path
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:
ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
# The logic below are similar to bentoml._internal.container.construct_containerfile
with open(bentomodel.path_of("bento.yaml"), "r") as f:
options = BentoInfo.from_yaml_file(f)
# NOTE: dockerfile_template is already included in the
# Dockerfile inside bento, and it is not relevant to
# construct_containerfile. Hence it is safe to set it to None here.
# See https://github.com/bentoml/BentoML/issues/3399.
docker_attrs = bentoml_cattr.unstructure(options.docker)
# NOTE: if users specify a dockerfile_template, we will
# save it to /env/docker/Dockerfile.template. This is necessary
# for the reconstruction of the Dockerfile.
if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template"
termui.echo(generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True,), fg="white")
return bentomodel.path

View File

@@ -26,7 +26,7 @@ from .. import termui
from ..._prompt import process_prompt
if t.TYPE_CHECKING:
from ..entrypoint import LiteralOutput
from ..entrypoint import LiteralOutput
@click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS)
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
@@ -37,27 +37,29 @@ if t.TYPE_CHECKING:
@click.option("--opt", help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]")
@click.pass_context
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
"""Get the default prompt used by OpenLLM."""
module = openllm.utils.EnvVarMixin(model_name).module
_memoized = {k: v[0] for k, v in _memoized.items() if v}
try:
template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None)
prompt_mapping = getattr(module, "PROMPT_MAPPING", None)
if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None
if callable(template):
if format is None:
if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.")
raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None
if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})")
_prompt_template = template(format)
else: _prompt_template = template
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
if machine: return repr(fully_formatted)
elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white")
elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white")
else:
termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta")
termui.echo(fully_formatted, fg="white")
except AttributeError: raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None
ctx.exit(0)
"""Get the default prompt used by OpenLLM."""
module = openllm.utils.EnvVarMixin(model_name).module
_memoized = {k: v[0] for k, v in _memoized.items() if v}
try:
template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None)
prompt_mapping = getattr(module, "PROMPT_MAPPING", None)
if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None
if callable(template):
if format is None:
if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.")
raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None
if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})")
_prompt_template = template(format)
else:
_prompt_template = template
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
if machine: return repr(fully_formatted)
elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white")
elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white")
else:
termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta")
termui.echo(fully_formatted, fg="white")
except AttributeError:
raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None
ctx.exit(0)

View File

@@ -26,9 +26,9 @@ from .. import termui
@click.command("list_bentos", context_settings=termui.CONTEXT_SETTINGS)
@click.pass_context
def cli(ctx: click.Context) -> None:
"""List available bentos built by OpenLLM."""
_local_bentos = {str(i.tag): i.info.labels["start_name"] for i in bentoml.list() if "start_name" in i.info.labels}
mapping = {k: [tag for tag, name in _local_bentos.items() if name == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())}
mapping = {k: v for k, v in mapping.items() if v}
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
ctx.exit(0)
"""List available bentos built by OpenLLM."""
_local_bentos = {str(i.tag): i.info.labels["start_name"] for i in bentoml.list() if "start_name" in i.info.labels}
mapping = {k: [tag for tag, name in _local_bentos.items() if name == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())}
mapping = {k: v for k, v in mapping.items() if v}
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
ctx.exit(0)

View File

@@ -26,16 +26,16 @@ from .. import termui
from .._factory import model_name_argument
if t.TYPE_CHECKING:
from ..._types import DictStrAny
from ..._types import DictStrAny
@click.command("list_models", context_settings=termui.CONTEXT_SETTINGS)
@model_name_argument(required=False)
def cli(model_name: str | None) -> DictStrAny:
"""This is equivalent to openllm models --show-available less the nice table."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in models}
if model_name is not None: ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k,v in ids_in_local_store.items()}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()}
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white")
return local_models
"""This is equivalent to openllm models --show-available less the nice table."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in models}
if model_name is not None: ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()}
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white")
return local_models

View File

@@ -23,13 +23,11 @@ from ..utils import get_debug_mode
from ..utils import get_quiet_mode
def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None:
attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho
if not get_quiet_mode(): call(text, **attrs)
attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho
if not get_quiet_mode(): call(text, **attrs)
COLUMNS = int(os.getenv("COLUMNS", str(120)))
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore}
__all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"]

View File

@@ -24,23 +24,25 @@ import importlib
import itertools
import typing as t
_import_structure: dict[str, list[str]] = {
"runtimes.grpc": ["AsyncGrpcClient", "GrpcClient"],
"runtimes.http": ["AsyncHTTPClient", "HTTPClient"],
}
_import_structure: dict[str, list[str]] = {"runtimes.grpc": ["AsyncGrpcClient", "GrpcClient"], "runtimes.http": ["AsyncHTTPClient", "HTTPClient"],}
if t.TYPE_CHECKING:
from openllm_client import AsyncGrpcClient as AsyncGrpcClient
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
from openllm_client import GrpcClient as GrpcClient
from openllm_client import HTTPClient as HTTPClient
from openllm_client import AsyncGrpcClient as AsyncGrpcClient
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
from openllm_client import GrpcClient as GrpcClient
from openllm_client import HTTPClient as HTTPClient
_module = "openllm_client"
__all__ = list(itertools.chain.from_iterable(_import_structure.values()))
def __dir__() -> list[str]: return sorted(__all__)
def __dir__() -> list[str]:
return sorted(__all__)
def __getattr__(name: str) -> t.Any:
if name in _import_structure: return importlib.import_module(f".{name}", _module)
try: module = next(module for module, attrs in _import_structure.items() if name in attrs)
except StopIteration: raise AttributeError(f"module {_module} has no attribute {name}") from None
return getattr(importlib.import_module(f".{module}", _module), name)
if name in _import_structure: return importlib.import_module(f".{name}", _module)
try:
module = next(module for module, attrs in _import_structure.items() if name in attrs)
except StopIteration:
raise AttributeError(f"module {_module} has no attribute {name}") from None
return getattr(importlib.import_module(f".{module}", _module), name)

View File

@@ -17,30 +17,25 @@ from __future__ import annotations
import bentoml
class OpenLLMException(bentoml.exceptions.BentoMLException):
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
class GpuNotAvailableError(OpenLLMException):
"""Raised when there is no GPU available in given system."""
"""Raised when there is no GPU available in given system."""
class ValidationError(OpenLLMException):
"""Raised when a validation fails."""
"""Raised when a validation fails."""
class ForbiddenAttributeError(OpenLLMException):
"""Raised when using an _internal field."""
"""Raised when using an _internal field."""
class MissingAnnotationAttributeError(OpenLLMException):
"""Raised when a field under openllm.LLMConfig is missing annotations."""
"""Raised when a field under openllm.LLMConfig is missing annotations."""
class MissingDependencyError(BaseException):
"""Raised when a dependency is missing."""
"""Raised when a dependency is missing."""
class Error(BaseException):
"""To be used instead of naked raise."""
"""To be used instead of naked raise."""
class FineTuneStrategyNotSupportedError(OpenLLMException):
"""Raised when a fine-tune strategy is not supported for given LLM."""
"""Raised when a fine-tune strategy is not supported for given LLM."""

View File

@@ -23,20 +23,21 @@ _MODELS: set[str] = {'auto', 'baichuan', 'chatglm', 'dolly_v2', 'falcon', 'flan_
# fmt: on
if t.TYPE_CHECKING:
# fmt: off
# update-models-import.py: start types
from . import auto as auto
from . import baichuan as baichuan
from . import chatglm as chatglm
from . import dolly_v2 as dolly_v2
from . import falcon as falcon
from . import flan_t5 as flan_t5
from . import gpt_neox as gpt_neox
from . import llama as llama
from . import mpt as mpt
from . import opt as opt
from . import stablelm as stablelm
from . import starcoder as starcoder
# update-models-import.py: stop types
# fmt: on
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}, module_spec=__spec__)
# fmt: off
# update-models-import.py: start types
from . import auto as auto
from . import baichuan as baichuan
from . import chatglm as chatglm
from . import dolly_v2 as dolly_v2
from . import falcon as falcon
from . import flan_t5 as flan_t5
from . import gpt_neox as gpt_neox
from . import llama as llama
from . import mpt as mpt
from . import opt as opt
from . import stablelm as stablelm
from . import starcoder as starcoder
# update-models-import.py: stop types
# fmt: on
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}, module_spec=__spec__)

View File

@@ -20,51 +20,63 @@ from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
from ...utils import is_vllm_available
_import_structure: dict[str, list[str]] = {
"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"],
"modeling_auto": ["MODEL_MAPPING_NAMES"],
"modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"],
"modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"],
"modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"],
}
_import_structure: dict[str, list[str]] = {"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"], "modeling_auto": ["MODEL_MAPPING_NAMES"], "modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"], "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"],}
try:
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
try:
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
try:
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
try:
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
if t.TYPE_CHECKING:
from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING
from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
from .configuration_auto import AutoConfig as AutoConfig
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
try:
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
try:
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
try:
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
try:
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING
from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
from .configuration_auto import AutoConfig as AutoConfig
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
try:
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
try:
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
try:
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
try:
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -18,77 +18,82 @@ import inflection
import openllm
from ...utils import ReprMixin
if t.TYPE_CHECKING:
import types
from collections import _odict_items, _odict_keys, _odict_values
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
import types
from collections import _odict_items, _odict_keys, _odict_values
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
else:
ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any
ConfigOrderedDict = OrderedDict
ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any
ConfigOrderedDict = OrderedDict
# NOTE: This is the entrypoint when adding new model config
CONFIG_MAPPING_NAMES = OrderedDict(
[
("chatglm", "ChatGLMConfig"),
("dolly_v2", "DollyV2Config"),
("falcon", "FalconConfig"),
("flan_t5", "FlanT5Config"),
("gpt_neox", "GPTNeoXConfig"),
("llama", "LlamaConfig"),
("mpt", "MPTConfig"),
("opt", "OPTConfig"),
("stablelm", "StableLMConfig"),
("starcoder", "StarCoderConfig"),
("baichuan", "BaichuanConfig"),
]
)
CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig"),])
class _LazyConfigMapping(ConfigOrderedDict, ReprMixin):
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
self._mapping = mapping
self._extra_content: dict[str, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}
def __getitem__(self, key: str) -> t.Any:
if key in self._extra_content: return self._extra_content[key]
if key not in self._mapping:
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
raise KeyError(key)
value, module_name = self._mapping[key], inflection.underscore(key)
if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
return getattr(openllm, value)
@property
def __repr_keys__(self) -> set[str]: return set(self._mapping.keys())
def __repr__(self) -> str: return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: yield from self._mapping.items()
def keys(self): return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
def values(self): return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
def items(self): return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
def __iter__(self): return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
def __contains__(self, item: t.Any): return item in self._mapping or item in self._extra_content
def register(self, key: str, value: t.Any):
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
self._extra_content[key] = value
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
self._mapping = mapping
self._extra_content: dict[str, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}
def __getitem__(self, key: str) -> t.Any:
if key in self._extra_content: return self._extra_content[key]
if key not in self._mapping:
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
raise KeyError(key)
value, module_name = self._mapping[key], inflection.underscore(key)
if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
return getattr(openllm, value)
@property
def __repr_keys__(self) -> set[str]:
return set(self._mapping.keys())
def __repr__(self) -> str:
return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]:
yield from self._mapping.items()
def keys(self):
return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
def values(self):
return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
def items(self):
return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
def __iter__(self):
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
def __contains__(self, item: t.Any):
return item in self._mapping or item in self._extra_content
def register(self, key: str, value: t.Any):
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
self._extra_content[key] = value
CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
# The below handle special alias when we call underscore to the name directly
# without processing camelcase first.
CONFIG_NAME_ALIASES: dict[str, str] = {
"chat_glm": "chatglm",
"stable_lm": "stablelm",
"star_coder": "starcoder",
"gpt_neo_x": "gpt_neox",
}
CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder", "gpt_neo_x": "gpt_neox",}
class AutoConfig:
def __init__(self, *_: t.Any, **__: t.Any): raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
@classmethod
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
model_name = inflection.underscore(model_name)
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
@classmethod
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
model_name = inflection.underscore(name)
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
def __init__(self, *_: t.Any, **__: t.Any):
raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
@classmethod
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
model_name = inflection.underscore(model_name)
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
@classmethod
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
model_name = inflection.underscore(name)
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")

View File

@@ -21,35 +21,40 @@ import inflection
import openllm
from ...utils import ReprMixin
if t.TYPE_CHECKING:
import types
from ..._llm import LLMRunner
from collections import _odict_items, _odict_keys, _odict_values
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
import types
from ..._llm import LLMRunner
from collections import _odict_items, _odict_keys, _odict_values
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
else:
ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any
ConfigModelOrderedDict = OrderedDict
ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any
ConfigModelOrderedDict = OrderedDict
logger = logging.getLogger(__name__)
class BaseAutoLLMClass:
_model_mapping: _LazyAutoMapping
def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
@classmethod
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
"""The lower level API for creating a LLM instance.
_model_mapping: _LazyAutoMapping
def __init__(self, *args: t.Any, **attrs: t.Any):
raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
@classmethod
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
"""The lower level API for creating a LLM instance.
```python
>>> import openllm
>>> llm = openllm.AutoLLM.for_model("flan-t5")
```
"""
llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
if ensure_available: llm.ensure_model_id_exists()
return llm
@classmethod
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
"""Create a LLM Runner for the given model name.
llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
if ensure_available: llm.ensure_model_id_exists()
return llm
@classmethod
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
"""Create a LLM Runner for the given model name.
Args:
model: The model name to instantiate.
@@ -59,76 +64,107 @@ class BaseAutoLLMClass:
Returns:
A LLM instance.
"""
runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters)
runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name}
for k in runner_attrs: del attrs[k]
return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs)
@classmethod
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
"""Register a new model for this class.
runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters)
runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name}
for k in runner_attrs:
del attrs[k]
return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs)
@classmethod
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
"""Register a new model for this class.
Args:
config_class: The configuration corresponding to the model to register.
llm_class: The runnable to register.
"""
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!")
cls._model_mapping.register(config_class, llm_class)
@classmethod
def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]:
config_class = openllm.AutoConfig.infer_class_from_name(name)
if config_class in cls._model_mapping: return cls._model_mapping[config_class]
raise ValueError(f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])}).")
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class:
raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!")
cls._model_mapping.register(config_class, llm_class)
@classmethod
def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]:
config_class = openllm.AutoConfig.infer_class_from_name(name)
if config_class in cls._model_mapping: return cls._model_mapping[config_class]
raise ValueError(f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])}).")
def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
if attr is None: return
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr): return getattr(module, attr)
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level.
openllm_module = importlib.import_module("openllm")
if module != openllm_module:
try: return getattribute_from_module(openllm_module, attr)
except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
raise ValueError(f"Could not find {attr} in {openllm_module}!")
if attr is None: return
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr): return getattr(module, attr)
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level.
openllm_module = importlib.import_module("openllm")
if module != openllm_module:
try:
return getattribute_from_module(openllm_module, attr)
except ValueError:
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
raise ValueError(f"Could not find {attr} in {openllm_module}!")
class _LazyAutoMapping(ConfigModelOrderedDict, ReprMixin):
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
This OrderedDict values() and keys() returns the list instead, so you don't
have to do list(mapping.values()) to get the list of values.
"""
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._extra_content: dict[t.Any, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}
def __len__(self): return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
if key in self._extra_content: return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
# Maybe there was several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype])
raise KeyError(key)
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
module_name = inflection.underscore(model_type)
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
return getattribute_from_module(self._modules[module_name], attr)
def keys(self): return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys()))
@property
def __repr_keys__(self) -> set[str]: return set(self._config_mapping.keys())
def __repr__(self) -> str: return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
def __bool__(self): return bool(self.keys())
def values(self): return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values()))
def items(self): return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items()))
def __iter__(self): return iter(self.keys())
def __contains__(self, item: t.Any):
if item in self._extra_content: return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
return self._reverse_config_mapping[item.__name__] in self._model_mapping
def register(self, key: t.Any, value: t.Any):
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
self._extra_content[key] = value
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._extra_content: dict[t.Any, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}
def __len__(self):
return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
if key in self._extra_content: return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
# Maybe there was several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype])
raise KeyError(key)
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
module_name = inflection.underscore(model_type)
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
return getattribute_from_module(self._modules[module_name], attr)
def keys(self):
return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys()))
@property
def __repr_keys__(self) -> set[str]:
return set(self._config_mapping.keys())
def __repr__(self) -> str:
return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]:
yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
def __bool__(self):
return bool(self.keys())
def values(self):
return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values()))
def items(self):
return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items()))
def __iter__(self):
return iter(self.keys())
def __contains__(self, item: t.Any):
if item in self._extra_content: return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
return self._reverse_config_mapping[item.__name__] in self._model_mapping
def register(self, key: t.Any, value: t.Any):
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
self._extra_content[key] = value
__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]

View File

@@ -17,21 +17,9 @@ from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_MAPPING_NAMES = OrderedDict(
[
("chatglm", "ChatGLM"),
("dolly_v2", "DollyV2"),
("falcon", "Falcon"),
("flan_t5", "FlanT5"),
("gpt_neox", "GPTNeoX"),
("llama", "Llama"),
("mpt", "MPT"),
("opt", "OPT"),
("stablelm", "StableLM"),
("starcoder", "StarCoder"),
("baichuan", "Baichuan"),
]
)
MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan"),])
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
class AutoLLM(BaseAutoLLMClass):
_model_mapping = MODEL_MAPPING
_model_mapping = MODEL_MAPPING

View File

@@ -16,12 +16,9 @@ from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_FLAX_MAPPING_NAMES = OrderedDict(
[
("flan_t5", "FlaxFlanT5"),
("opt", "FlaxOPT"),
]
)
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT"),])
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
class AutoFlaxLLM(BaseAutoLLMClass):
_model_mapping = MODEL_FLAX_MAPPING
_model_mapping = MODEL_FLAX_MAPPING

View File

@@ -16,12 +16,9 @@ from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_TF_MAPPING_NAMES = OrderedDict(
[
("flan_t5", "TFFlanT5"),
("opt", "TFOPT"),
]
)
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT"),])
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
class AutoTFLLM(BaseAutoLLMClass):
_model_mapping = MODEL_TF_MAPPING
_model_mapping = MODEL_TF_MAPPING

View File

@@ -16,12 +16,9 @@ from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_VLLM_MAPPING_NAMES = OrderedDict(
[
("llama", "VLLMLlama"),
("opt", "VLLMOPT")
]
)
MODEL_VLLM_MAPPING_NAMES = OrderedDict([("llama", "VLLMLlama"), ("opt", "VLLMOPT")])
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
class AutoVLLM(BaseAutoLLMClass):
_model_mapping = MODEL_VLLM_MAPPING
_model_mapping = MODEL_VLLM_MAPPING

View File

@@ -18,18 +18,24 @@ from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_cpm_kernels_available
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_baichuan": ["BaichuanConfig", "START_BAICHUAN_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_baichuan"] = ["Baichuan"]
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_baichuan"] = ["Baichuan"]
if t.TYPE_CHECKING:
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING
from .configuration_baichuan import BaichuanConfig as BaichuanConfig
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING
from .configuration_baichuan import BaichuanConfig as BaichuanConfig
try:
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_baichuan import Baichuan as Baichuan
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
try:
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_baichuan import Baichuan as Baichuan
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
import openllm
class BaichuanConfig(openllm.LLMConfig):
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
Baichuan-7B is based on Transformer architecture,
which contains 7 billion parameters and trained on approximately 1.2 trillion tokens.
@@ -23,28 +24,16 @@ class BaichuanConfig(openllm.LLMConfig):
and English benchmarks (C-Eval, MMLU, etc).
Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"timeout": 3600000,
"requires_gpu": True,
"url": "https://github.com/baichuan-inc/Baichuan-7B",
"requirements": ["cpm-kernels", "sentencepiece"],
"architecture": "BaiChuanForCausalLM",
"default_id": "baichuan-inc/baichuan-7b",
"model_ids": [
"baichuan-inc/baichuan-7b",
"baichuan-inc/baichuan-13b-base",
"baichuan-inc/baichuan-13b-chat",
"fireballoon/baichuan-vicuna-chinese-7b",
"fireballoon/baichuan-vicuna-7b",
"hiyouga/baichuan-7b-sft",
],
}
class GenerationConfig:
max_new_tokens: int = 2048
top_p: float = 0.7
temperature: float = 0.95
__config__ = {
"name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/baichuan-inc/Baichuan-7B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "BaiChuanForCausalLM", "default_id": "baichuan-inc/baichuan-7b",
"model_ids": ["baichuan-inc/baichuan-7b", "baichuan-inc/baichuan-13b-base", "baichuan-inc/baichuan-13b-chat", "fireballoon/baichuan-vicuna-chinese-7b", "fireballoon/baichuan-vicuna-7b", "hiyouga/baichuan-7b-sft",],
}
class GenerationConfig:
max_new_tokens: int = 2048
top_p: float = 0.7
temperature: float = 0.95
START_BAICHUAN_COMMAND_DOCSTRING = """\
Run a LLMServer for Baichuan model.

View File

@@ -18,12 +18,18 @@ from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

View File

@@ -18,17 +18,23 @@ from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_cpm_kernels_available
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_chatglm": ["ChatGLMConfig", "START_CHATGLM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_chatglm"] = ["ChatGLM"]
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_chatglm"] = ["ChatGLM"]
if t.TYPE_CHECKING:
from .configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING
from .configuration_chatglm import ChatGLMConfig as ChatGLMConfig
try:
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_chatglm import ChatGLM as ChatGLM
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING
from .configuration_chatglm import ChatGLMConfig as ChatGLMConfig
try:
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_chatglm import ChatGLM as ChatGLM
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
import openllm
class ChatGLMConfig(openllm.LLMConfig):
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
With the quantization technique, users can deploy locally on consumer-grade graphics cards
(only 6GB of GPU memory is required at the INT4 quantization level).
@@ -27,34 +28,20 @@ class ChatGLMConfig(openllm.LLMConfig):
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"timeout": 3600000,
"requires_gpu": True,
"url": "https://github.com/THUDM/ChatGLM-6B",
"requirements": ["cpm-kernels", "sentencepiece"],
"architecture": "ChatGLMForConditionalGeneration",
"default_id": "thudm/chatglm-6b",
"model_ids": [
"thudm/chatglm-6b",
"thudm/chatglm-6b-int8",
"thudm/chatglm-6b-int4",
"thudm/chatglm2-6b",
"thudm/chatglm2-6b-int4",
],
}
retain_history: bool = openllm.LLMConfig.Field(
False,
description="""Whether to retain history given to the model.
If set to True, then the model will retain given history.""",
)
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
class GenerationConfig:
max_new_tokens: int = 2048
num_beams: int = 1
top_p: float = 0.7
temperature: float = 0.95
__config__ = {
"name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/THUDM/ChatGLM-6B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "ChatGLMForConditionalGeneration", "default_id": "thudm/chatglm-6b",
"model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4", "thudm/chatglm2-6b", "thudm/chatglm2-6b-int4",],
}
retain_history: bool = openllm.LLMConfig.Field(False, description="""Whether to retain history given to the model.
If set to True, then the model will retain given history.""",)
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
class GenerationConfig:
max_new_tokens: int = 2048
num_beams: int = 1
top_p: float = 0.7
temperature: float = 0.95
START_CHATGLM_COMMAND_DOCSTRING = """\
Run a LLMServer for ChatGLM model.

View File

@@ -17,38 +17,45 @@ import openllm
from ..._llm import LLMEmbeddings
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
prompt_text = ""
if use_default_prompt_template and chat_history is not None:
for i, (old_query, response) in enumerate(chat_history): prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001
else: prompt_text = prompt
postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None}
# NOTE: The rest of attrs should be kwargs for GenerationConfig
generate_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs}
return prompt_text, generate_kwargs, postprocess_generate_kwargs
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any):
generated, history = generation_result
if self.config.retain_history:
assert chat_history is not None, "'retain_history' is True while there is no history provided."
chat_history.extend(history)
return generated
def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
with torch.inference_mode():
self.model.eval()
# Only use half precision if the model is not yet quantized
if self.config.use_half_precision: self.model.half()
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, output_hidden_states=True)
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)
embeddings.append(data.tolist())
num_tokens += len(input_ids[0])
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
prompt_text = ""
if use_default_prompt_template and chat_history is not None:
for i, (old_query, response) in enumerate(chat_history):
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001
else:
prompt_text = prompt
postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None}
# NOTE: The rest of attrs should be kwargs for GenerationConfig
generate_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs}
return prompt_text, generate_kwargs, postprocess_generate_kwargs
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any):
generated, history = generation_result
if self.config.retain_history:
assert chat_history is not None, "'retain_history' is True while there is no history provided."
chat_history.extend(history)
return generated
def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
with torch.inference_mode():
self.model.eval()
# Only use half precision if the model is not yet quantized
if self.config.use_half_precision: self.model.half()
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, output_hidden_states=True)
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)
embeddings.append(data.tolist())
num_tokens += len(input_ids[0])
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)

View File

@@ -17,17 +17,23 @@ import typing as t
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_dolly_v2"] = ["DollyV2"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_dolly_v2"] = ["DollyV2"]
if t.TYPE_CHECKING:
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING
from .configuration_dolly_v2 import DollyV2Config as DollyV2Config
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_dolly_v2 import DollyV2 as DollyV2
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING
from .configuration_dolly_v2 import DollyV2Config as DollyV2Config
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_dolly_v2 import DollyV2 as DollyV2
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -15,8 +15,9 @@ from __future__ import annotations
import typing as t
import openllm
if t.TYPE_CHECKING: import transformers
class DollyV2Config(openllm.LLMConfig):
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k
generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming,
@@ -27,22 +28,16 @@ class DollyV2Config(openllm.LLMConfig):
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
"""
__config__ = {
"timeout": 3600000,
"url": "https://github.com/databrickslabs/dolly",
"architecture": "GPTNeoXForCausalLM",
"default_id": "databricks/dolly-v2-3b",
"model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],
}
return_full_text: bool = openllm.LLMConfig.Field(
False, description="Whether to return the full prompt to the users."
)
class GenerationConfig:
temperature: float = 0.9
top_p: float = 0.92
top_k: int = 5
max_new_tokens: int = 256
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
__config__ = {"timeout": 3600000, "url": "https://github.com/databrickslabs/dolly", "architecture": "GPTNeoXForCausalLM", "default_id": "databricks/dolly-v2-3b", "model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],}
return_full_text: bool = openllm.LLMConfig.Field(False, description="Whether to return the full prompt to the users.")
class GenerationConfig:
temperature: float = 0.9
top_p: float = 0.92
top_k: int = 5
max_new_tokens: int = 256
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
START_DOLLY_V2_COMMAND_DOCSTRING = """\
Run a LLMServer for dolly-v2 model.
@@ -74,8 +69,9 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
{instruction}
{response_key}
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
@@ -90,6 +86,6 @@ def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str)
Returns:
int: the token ID for the given key.
"""
token_ids = tokenizer.encode(key)
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
return token_ids[0]
token_ids = tokenizer.encode(key)
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
return token_ids[0]

View File

@@ -24,104 +24,130 @@ from ..._prompt import process_prompt
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow")
logger = logging.getLogger(__name__)
@t.overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: ...
@t.overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: ...
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
class InstructionTextGenerationPipeline(transformers.Pipeline):
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any):
if t.TYPE_CHECKING: assert self.tokenizer is not None
preprocess_params: dict[str, t.Any] = {}
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
# append a newline to yield a single token. find whatever token is configured for the response key.
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
response_key_token_id = None
end_key_token_id = None
if tokenizer_response_key:
try:
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
# Ensure generation stops once it generates "### End"
generate_kwargs["eos_token_id"] = end_key_token_id
except ValueError: pass
forward_params = generate_kwargs
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
return preprocess_params, forward_params, postprocess_params
def preprocess(self, input_: str, **generate_kwargs: t.Any):
if t.TYPE_CHECKING: assert self.tokenizer is not None
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
inputs = self.tokenizer(prompt_text, return_tensors="pt")
inputs["prompt_text"] = prompt_text
inputs["instruction_text"] = input_
return inputs
def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any):
if t.TYPE_CHECKING: assert self.tokenizer is not None
input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None)
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
else: in_b = input_ids.shape[0]
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
instruction_text = model_inputs.pop("instruction_text")
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False):
if t.TYPE_CHECKING: assert self.tokenizer is not None
generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
records: list[dict[t.Literal["generated_text"], str]] = []
for sequence in generated_sequence:
# The response will be set to this variable if we can identify it.
decoded = None
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
if response_key_token_id and end_key_token_id:
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
# prompt, we should definitely find it. We will return the tokens found after this token.
try: response_pos = sequence.index(response_key_token_id)
except ValueError: response_pos = None
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
if response_pos:
# Next find where "### End" is located. The model has been trained to end its responses with this
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
# this token, as the response could be truncated. If we don't find it then just return everything
# to the end. Note that even though we set eos_token_id, we still see the this token at the end.
try: end_pos = sequence.index(end_key_token_id)
except ValueError: end_pos = None
decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
if not decoded:
# Otherwise we'll decode everything and use a regex to find the response and end.
fully_decoded = self.tokenizer.decode(sequence)
# The response appears after "### Response:". The model has been trained to append "### End" at the
# end.
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else:
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
# return everything after "### Response:".
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
rec = {"generated_text": decoded}
records.append(rec)
return records
if _init: return InstructionTextGenerationPipeline()
return InstructionTextGenerationPipeline
@t.overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline:
...
@t.overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]:
...
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
class InstructionTextGenerationPipeline(transformers.Pipeline):
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any):
super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any):
if t.TYPE_CHECKING: assert self.tokenizer is not None
preprocess_params: dict[str, t.Any] = {}
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
# append a newline to yield a single token. find whatever token is configured for the response key.
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
response_key_token_id = None
end_key_token_id = None
if tokenizer_response_key:
try:
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
# Ensure generation stops once it generates "### End"
generate_kwargs["eos_token_id"] = end_key_token_id
except ValueError:
pass
forward_params = generate_kwargs
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
return preprocess_params, forward_params, postprocess_params
def preprocess(self, input_: str, **generate_kwargs: t.Any):
if t.TYPE_CHECKING: assert self.tokenizer is not None
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
inputs = self.tokenizer(prompt_text, return_tensors="pt")
inputs["prompt_text"] = prompt_text
inputs["instruction_text"] = input_
return inputs
def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any):
if t.TYPE_CHECKING: assert self.tokenizer is not None
input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None)
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
else: in_b = input_ids.shape[0]
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
instruction_text = model_inputs.pop("instruction_text")
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False):
if t.TYPE_CHECKING: assert self.tokenizer is not None
generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
records: list[dict[t.Literal["generated_text"], str]] = []
for sequence in generated_sequence:
# The response will be set to this variable if we can identify it.
decoded = None
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
if response_key_token_id and end_key_token_id:
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
# prompt, we should definitely find it. We will return the tokens found after this token.
try:
response_pos = sequence.index(response_key_token_id)
except ValueError:
response_pos = None
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
if response_pos:
# Next find where "### End" is located. The model has been trained to end its responses with this
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
# this token, as the response could be truncated. If we don't find it then just return everything
# to the end. Note that even though we set eos_token_id, we still see the this token at the end.
try:
end_pos = sequence.index(end_key_token_id)
except ValueError:
end_pos = None
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
if not decoded:
# Otherwise we'll decode everything and use a regex to find the response and end.
fully_decoded = self.tokenizer.decode(sequence)
# The response appears after "### Response:". The model has been trained to append "### End" at the
# end.
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else:
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
# return everything after "### Response:".
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
rec = {"generated_text": decoded}
records.append(rec)
return records
if _init: return InstructionTextGenerationPipeline()
return InstructionTextGenerationPipeline
class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]):
__openllm_internal__ = True
@property
def import_kwargs(self): return {"device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16}, {"padding_side": "left"}
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(model=transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), tokenizer=self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"]
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
llm_config = self.config.model_construct_env(**attrs)
with torch.inference_mode(): return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())
__openllm_internal__ = True
@property
def import_kwargs(self):
return {"device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16}, {"padding_side": "left"}
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline:
return get_pipeline(model=transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), tokenizer=self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str:
return generation_result[0]["generated_text"]
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
llm_config = self.config.model_construct_env(**attrs)
with torch.inference_mode():
return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())

View File

@@ -17,17 +17,23 @@ import typing as t
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_falcon": ["FalconConfig", "START_FALCON_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_falcon"] = ["Falcon"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_falcon"] = ["Falcon"]
if t.TYPE_CHECKING:
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING
from .configuration_falcon import FalconConfig as FalconConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_falcon import Falcon as Falcon
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING
from .configuration_falcon import FalconConfig as FalconConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_falcon import Falcon as Falcon
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,45 +13,26 @@
# limitations under the License.
from __future__ import annotations
import openllm
class FalconConfig(openllm.LLMConfig):
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
It is made available under the TII Falcon LLM License.
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"requires_gpu": True,
"timeout": int(36e6),
"url": "https://falconllm.tii.ae/",
"requirements": ["einops", "xformers"],
"architecture": "FalconForCausalLM",
"default_id": "tiiuae/falcon-7b",
"model_ids": [
"tiiuae/falcon-7b",
"tiiuae/falcon-40b",
"tiiuae/falcon-7b-instruct",
"tiiuae/falcon-40b-instruct",
],
"fine_tune_strategies": (
{
"adapter_type": "lora",
"r": 64,
"lora_alpha": 16,
"lora_dropout": 0.1,
"bias": "none",
"target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
},
),
}
class GenerationConfig:
max_new_tokens: int = 200
top_k: int = 10
num_return_sequences: int = 1
num_beams: int = 4
early_stopping: bool = True
__config__ = {
"name_type": "lowercase", "trust_remote_code": True, "requires_gpu": True, "timeout": int(36e6), "url": "https://falconllm.tii.ae/", "requirements": ["einops", "xformers"], "architecture": "FalconForCausalLM", "default_id": "tiiuae/falcon-7b", "model_ids": ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct",],
"fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none", "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],},),
}
class GenerationConfig:
max_new_tokens: int = 200
top_k: int = 10
num_return_sequences: int = 1
num_beams: int = 4
early_stopping: bool = True
START_FALCON_COMMAND_DOCSTRING = """\
Run a LLMServer for FalconLM model.

View File

@@ -18,21 +18,31 @@ from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() else None}, {}
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env( eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[: -len(stop_seq)]
return [{"generated_text": result}]
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() else None}, {}
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
return [{"generated_text": result}]

View File

@@ -20,33 +20,47 @@ from ...utils import LazyModule
from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_flan_t5": ["FlanT5Config", "START_FLAN_T5_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_flan_t5"] = ["FlanT5"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_flan_t5"] = ["FlanT5"]
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"]
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"]
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"]
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"]
if t.TYPE_CHECKING:
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING
from .configuration_flan_t5 import FlanT5Config as FlanT5Config
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_flan_t5 import FlanT5 as FlanT5
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING
from .configuration_flan_t5 import FlanT5Config as FlanT5Config
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_flan_t5 import FlanT5 as FlanT5
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,32 +13,23 @@
# limitations under the License.
from __future__ import annotations
import openllm
class FlanT5Config(openllm.LLMConfig):
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
It is an enhanced version of T5 that has been finetuned in a mixture of tasks.
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
"""
__config__ = {
"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5",
"default_id": "google/flan-t5-large",
"architecture": "T5ForConditionalGeneration",
"model_ids": [
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
],
"model_type": "seq2seq_lm",
}
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 2048
top_k: int = 50
top_p: float = 0.4
repetition_penalty = 1.0
__config__ = {"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5", "default_id": "google/flan-t5-large", "architecture": "T5ForConditionalGeneration", "model_ids": ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl",], "model_type": "seq2seq_lm",}
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 2048
top_k: int = 50
top_p: float = 0.4
repetition_penalty = 1.0
START_FLAN_T5_COMMAND_DOCSTRING = """\
Run a LLMServer for FLAN-T5 model.

View File

@@ -19,20 +19,28 @@ from ..._llm import LLMEmbeddings
from ..._prompt import process_prompt
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, decoder_input_ids=input_ids)
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)
embeddings.append(data.tolist())
num_tokens += len(input_ids[0])
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, decoder_input_ids=input_ids)
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)
embeddings.append(data.tolist())
num_tokens += len(input_ids[0])
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)

View File

@@ -17,13 +17,18 @@ import openllm
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING: import transformers # noqa: F401
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if decoder_start_token_id is None: decoder_start_token_id = 0
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="np")["input_ids"], do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), decoder_start_token_id=decoder_start_token_id).sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True)
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if decoder_start_token_id is None: decoder_start_token_id = 0
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="np")["input_ids"], do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), decoder_start_token_id=decoder_start_token_id).sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True)

View File

@@ -17,8 +17,15 @@ import openllm
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING: import transformers # noqa: F401
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)

View File

@@ -17,17 +17,23 @@ import typing as t
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_gpt_neox": ["GPTNeoXConfig", "START_GPT_NEOX_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_gpt_neox"] = ["GPTNeoX"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_gpt_neox"] = ["GPTNeoX"]
if t.TYPE_CHECKING:
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING
from .configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_gpt_neox import GPTNeoX as GPTNeoX
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING
from .configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_gpt_neox import GPTNeoX as GPTNeoX
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
import openllm
class GPTNeoXConfig(openllm.LLMConfig):
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
It is, to the best of our knowledge, the largest dense autoregressive model
that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights,
@@ -28,19 +29,13 @@ class GPTNeoXConfig(openllm.LLMConfig):
Refer to [GPTNeoX's model card](https://huggingface.co/docs/transformers/model_doc/gpt_neox)
for more information.
"""
__config__ = {
"model_name": "gpt_neox",
"start_name": "gpt-neox",
"requires_gpu": True,
"architecture": "GPTNeoXForCausalLM",
"url": "https://github.com/EleutherAI/gpt-neox",
"default_id": "eleutherai/gpt-neox-20b",
"model_ids": ["eleutherai/gpt-neox-20b"],
}
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 100
__config__ = {"model_name": "gpt_neox", "start_name": "gpt-neox", "requires_gpu": True, "architecture": "GPTNeoXForCausalLM", "url": "https://github.com/EleutherAI/gpt-neox", "default_id": "eleutherai/gpt-neox-20b", "model_ids": ["eleutherai/gpt-neox-20b"],}
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 100
START_GPT_NEOX_COMMAND_DOCSTRING = """\
Run a LLMServer for GPTNeoX model.

View File

@@ -20,15 +20,25 @@ from ..._prompt import process_prompt
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
if self.config.use_half_precision: model.half()
return model
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])))
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
return generation_result[0]
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
if self.config.use_half_precision: model.half()
return model
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])))

View File

@@ -18,26 +18,36 @@ from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
from ...utils import is_vllm_available
_import_structure: dict[str, list[str]] = {"configuration_llama": ["LlamaConfig", "START_LLAMA_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]}
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_vllm_llama"] = ["VLLMLlama"]
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_llama"] = ["VLLMLlama"]
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_llama"] = ["Llama"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_llama"] = ["Llama"]
if t.TYPE_CHECKING:
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING
from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING
from .configuration_llama import LlamaConfig as LlamaConfig
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_vllm_llama import VLLMLlama as VLLMLlama
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_llama import Llama as Llama
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING
from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING
from .configuration_llama import LlamaConfig as LlamaConfig
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_vllm_llama import VLLMLlama as VLLMLlama
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_llama import Llama as Llama
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -14,8 +14,9 @@
from __future__ import annotations
import typing as t
import openllm
class LlamaConfig(openllm.LLMConfig):
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
It is a collection of foundation language models ranging from 7B to 65B parameters.
@@ -26,54 +27,24 @@ class LlamaConfig(openllm.LLMConfig):
Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama)
for more information.
"""
use_llama2_prompt: bool = openllm.LLMConfig.Field(True, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.")
__config__ = {
"name_type": "lowercase",
"url": "https://github.com/facebookresearch/llama",
"default_id": "huggyllama/llama-7b",
"default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"},
"architecture": "LlamaForCausalLM",
"requirements": ["fairscale", "sentencepiece"],
"model_ids": [
"meta-llama/Llama-2-70b-chat-hf",
"meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-70b-hf",
"meta-llama/Llama-2-13b-hf",
"meta-llama/Llama-2-7b-hf",
"NousResearch/llama-2-70b-chat-hf",
"NousResearch/llama-2-13b-chat-hf",
"NousResearch/llama-2-7b-chat-hf",
"NousResearch/llama-2-70b-hf",
"NousResearch/llama-2-13b-hf",
"NousResearch/llama-2-7b-hf",
"openlm-research/open_llama_7b_v2",
"openlm-research/open_llama_3b_v2",
"openlm-research/open_llama_13b",
"huggyllama/llama-65b",
"huggyllama/llama-30b",
"huggyllama/llama-13b",
"huggyllama/llama-7b",
],
"tokenizer_class": "LlamaTokenizerFast",
"fine_tune_strategies": (
{
"adapter_type": "lora",
"r": 64,
"lora_alpha": 16,
"lora_dropout": 0.1,
"bias": "none",
},
),
}
class GenerationConfig:
max_new_tokens: int = 256
temperature: float = 0.45
top_p: float = 0.95
top_k: int = 12
class SamplingParams:
best_of: int = 1
presence_penalty: float = 0.5
use_llama2_prompt: bool = openllm.LLMConfig.Field(True, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.")
__config__ = {
"name_type": "lowercase", "url": "https://github.com/facebookresearch/llama", "default_id": "huggyllama/llama-7b", "default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"}, "architecture": "LlamaForCausalLM", "requirements": ["fairscale", "sentencepiece"], "model_ids": [
"meta-llama/Llama-2-70b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "NousResearch/llama-2-70b-chat-hf", "NousResearch/llama-2-13b-chat-hf", "NousResearch/llama-2-7b-chat-hf", "NousResearch/llama-2-70b-hf", "NousResearch/llama-2-13b-hf",
"NousResearch/llama-2-7b-hf", "openlm-research/open_llama_7b_v2", "openlm-research/open_llama_3b_v2", "openlm-research/open_llama_13b", "huggyllama/llama-65b", "huggyllama/llama-30b", "huggyllama/llama-13b", "huggyllama/llama-7b",
], "tokenizer_class": "LlamaTokenizerFast", "fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none",},),
}
class GenerationConfig:
max_new_tokens: int = 256
temperature: float = 0.45
top_p: float = 0.95
top_k: int = 12
class SamplingParams:
best_of: int = 1
presence_penalty: float = 0.5
START_LLAMA_COMMAND_DOCSTRING = """\
Run a LLMServer for Llama model.
@@ -112,5 +83,7 @@ SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = "[INST]", "[/INST]", "<<SY
# TODO: support history and v1 prompt implementation
_v1_prompt, _v2_prompt = """{instruction}""", """{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} """.format(start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction="{instruction}", end_key=EINST_KEY)
PROMPT_MAPPING = {"v1": _v1_prompt, "v2": _v2_prompt}
def _get_prompt(model_type: t.Literal["v1", "v2"]) -> str: return PROMPT_MAPPING[model_type]
DEFAULT_PROMPT_TEMPLATE = _get_prompt

View File

@@ -21,22 +21,28 @@ from ..._prompt import process_prompt
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
logger = logging.getLogger(__name__)
class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True)
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
with torch.inference_mode():
data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True)
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
with torch.inference_mode():
data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())

View File

@@ -19,8 +19,10 @@ from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)
class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}

View File

@@ -17,18 +17,24 @@ import typing as t
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_mpt": ["MPTConfig", "START_MPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_mpt"] = ["MPT"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_mpt"] = ["MPT"]
if t.TYPE_CHECKING:
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING
from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING
from .configuration_mpt import MPTConfig as MPTConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_mpt import MPT as MPT
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING
from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING
from .configuration_mpt import MPTConfig as MPTConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_mpt import MPT as MPT
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -16,8 +16,9 @@ import typing as t
import openllm
if t.TYPE_CHECKING: MPTPromptType = t.Literal["default", "instruct", "chat", "storywriter"]
else: MPTPromptType = str
class MPTConfig(openllm.LLMConfig):
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
This model was trained by [MosaicML](https://www.mosaicml.com/).
@@ -25,30 +26,18 @@ class MPTConfig(openllm.LLMConfig):
on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
for more details on specific models.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"url": "https://huggingface.co/mosaicml",
"default_id": "mosaicml/mpt-7b-instruct",
"timeout": int(36e6),
"requirements": ["triton", "einops"],
"architecture": "MPTForCausalLM",
"model_ids": [
"mosaicml/mpt-7b",
"mosaicml/mpt-7b-instruct",
"mosaicml/mpt-7b-chat",
"mosaicml/mpt-7b-storywriter",
"mosaicml/mpt-30b",
"mosaicml/mpt-30b-instruct",
"mosaicml/mpt-30b-chat",
],
}
prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="""Given prompt type for running MPT. Default will be inferred from model name if pretrained.""")
max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)")
class GenerationConfig:
max_new_tokens: int = 128
temperature: float = 0
top_p: float = 0.8
__config__ = {
"name_type": "lowercase", "trust_remote_code": True, "url": "https://huggingface.co/mosaicml", "default_id": "mosaicml/mpt-7b-instruct", "timeout": int(36e6), "requirements": ["triton", "einops"], "architecture": "MPTForCausalLM",
"model_ids": ["mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct", "mosaicml/mpt-7b-chat", "mosaicml/mpt-7b-storywriter", "mosaicml/mpt-30b", "mosaicml/mpt-30b-instruct", "mosaicml/mpt-30b-chat",],
}
prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="""Given prompt type for running MPT. Default will be inferred from model name if pretrained.""")
max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)")
class GenerationConfig:
max_new_tokens: int = 128
temperature: float = 0
top_p: float = 0.8
START_MPT_COMMAND_DOCSTRING = """\
Run a LLMServer for MPT model.
@@ -86,5 +75,8 @@ _chat_prompt, _default_prompt, _instruct_prompt = """{instruction}""", """{instr
{response_key}
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
PROMPT_MAPPING = {"default": _default_prompt, "instruct": _instruct_prompt, "storywriter": _default_prompt, "chat": _chat_prompt}
def _get_prompt(model_type: str) -> str: return PROMPT_MAPPING[model_type]
def _get_prompt(model_type: str) -> str:
return PROMPT_MAPPING[model_type]
DEFAULT_PROMPT_TEMPLATE = _get_prompt

View File

@@ -23,56 +23,71 @@ from ...utils import generate_labels, is_triton_available
if t.TYPE_CHECKING: import transformers, torch
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
logger = logging.getLogger(__name__)
def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig:
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton"
else: logger.debug("'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'")
# setting max_seq_len
config.max_seq_len = max_sequence_length
return config
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton"
else: logger.debug("'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'")
# setting max_seq_len
config.max_seq_len = max_sequence_length
return config
class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXTokenizerFast"]):
__openllm_internal__ = True
def llm_post_init(self): self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
_, tokenizer_attrs = self.llm_parameters
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
attrs.pop("low_cpu_mem_usage", None)
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
try: return bentoml.transformers.save_model( self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
finally: torch.cuda.empty_cache()
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
trust_remote_code = attrs.pop("trust_remote_code", True)
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs)
model.tie_weights()
return model
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
_template = None
if use_default_prompt_template:
if prompt_type is None:
if "instruct" in self.model_id: prompt_type = "instruct"
elif "storywriter" in self.model_id: prompt_type = "storywriter"
elif "chat" in self.model_id: prompt_type = "chat"
else: prompt_type = "default"
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
llm_config = self.config.model_construct_env(**attrs)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
with torch.inference_mode():
if torch.cuda.is_available():
with torch.autocast("cuda", torch.float16):
generated_tensors = self.model.generate(**inputs, **attrs)
else: generated_tensors = self.model.generate(**inputs, **attrs)
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)
__openllm_internal__ = True
def llm_post_init(self):
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
_, tokenizer_attrs = self.llm_parameters
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
attrs.pop("low_cpu_mem_usage", None)
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
try:
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
finally:
torch.cuda.empty_cache()
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
trust_remote_code = attrs.pop("trust_remote_code", True)
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs)
model.tie_weights()
return model
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
_template = None
if use_default_prompt_template:
if prompt_type is None:
if "instruct" in self.model_id: prompt_type = "instruct"
elif "storywriter" in self.model_id: prompt_type = "storywriter"
elif "chat" in self.model_id: prompt_type = "chat"
else: prompt_type = "default"
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
llm_config = self.config.model_construct_env(**attrs)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
with torch.inference_mode():
if torch.cuda.is_available():
with torch.autocast("cuda", torch.float16):
generated_tensors = self.model.generate(**inputs, **attrs)
else:
generated_tensors = self.model.generate(**inputs, **attrs)
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)

View File

@@ -20,41 +20,59 @@ from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
from ...utils import is_vllm_available
_import_structure: dict[str, list[str]] = {"configuration_opt": ["OPTConfig", "START_OPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_opt"] = ["OPT"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_opt"] = ["OPT"]
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_flax_opt"] = ["FlaxOPT"]
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_flax_opt"] = ["FlaxOPT"]
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_vllm_opt"] = ["VLLMOPT"]
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_opt"] = ["VLLMOPT"]
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_tf_opt"] = ["TFOPT"]
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_tf_opt"] = ["TFOPT"]
if t.TYPE_CHECKING:
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING
from .configuration_opt import OPTConfig as OPTConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_opt import OPT as OPT
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_flax_opt import FlaxOPT as FlaxOPT
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_vllm_opt import VLLMOPT as VLLMOPT
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_tf_opt import TFOPT as TFOPT
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING
from .configuration_opt import OPTConfig as OPTConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_opt import OPT as OPT
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_flax_opt import FlaxOPT as FlaxOPT
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_vllm_opt import VLLMOPT as VLLMOPT
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_tf_opt import TFOPT as TFOPT
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
import openllm
class OPTConfig(openllm.LLMConfig):
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
OPT was predominantly pretrained with English text, but a small amount of non-English data is still present
within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM)
@@ -23,37 +24,18 @@ class OPTConfig(openllm.LLMConfig):
Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": False,
"url": "https://huggingface.co/docs/transformers/model_doc/opt",
"default_id": "facebook/opt-1.3b",
"architecture": "OPTForCausalLM",
"model_ids": [
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
"facebook/opt-66b",
],
"fine_tune_strategies": (
{
"adapter_type": "lora",
"r": 16,
"lora_alpha": 32,
"target_modules": ["q_proj", "v_proj"],
"lora_dropout": 0.05,
"bias": "none",
},
),
}
format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""")
class GenerationConfig:
top_k: int = 15
temperature: float = 0.75
max_new_tokens: int = 1024
num_return_sequences: int = 1
__config__ = {
"name_type": "lowercase", "trust_remote_code": False, "url": "https://huggingface.co/docs/transformers/model_doc/opt", "default_id": "facebook/opt-1.3b", "architecture": "OPTForCausalLM", "model_ids": ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-66b",],
"fine_tune_strategies": ({"adapter_type": "lora", "r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], "lora_dropout": 0.05, "bias": "none",},),
}
format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""")
class GenerationConfig:
top_k: int = 15
temperature: float = 0.75
max_new_tokens: int = 1024
num_return_sequences: int = 1
START_OPT_COMMAND_DOCSTRING = """\
Run a LLMServer for OPT model.

View File

@@ -22,18 +22,26 @@ from ...utils import generate_labels
if t.TYPE_CHECKING: import transformers
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {}, {"padding_side": "left", "truncation_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.pad_token_id = config.pad_token_id
return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode( self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True)
class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {}, {"padding_side": "left", "truncation_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.pad_token_id = config.pad_token_id
return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True)

View File

@@ -18,23 +18,34 @@ import openllm
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING:
import torch, transformers
import torch, transformers
else:
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
def llm_post_init(self): self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left", "truncation_side": "left"}
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
torch_dtype = attrs.pop("torch_dtype", self.dtype)
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs)
return model
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
__openllm_internal__ = True
def llm_post_init(self):
self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left", "truncation_side": "left"}
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
torch_dtype = attrs.pop("torch_dtype", self.dtype)
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs)
return model
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)

View File

@@ -22,17 +22,26 @@ from ...utils import generate_labels
if t.TYPE_CHECKING: import transformers
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {}, {"padding_side": "left", "truncation_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.pad_token_id = config.pad_token_id
return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {}, {"padding_side": "left", "truncation_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.pad_token_id = config.pad_token_id
return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)

View File

@@ -19,9 +19,13 @@ import openllm
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import process_prompt
if t.TYPE_CHECKING:
import vllm, transformers
import vllm, transformers
logger = logging.getLogger(__name__)
class VLLMOPT(openllm.LLM["vllm.LLMEngine", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
tokenizer_id = "local"
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
__openllm_internal__ = True
tokenizer_id = "local"
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}

View File

@@ -17,17 +17,23 @@ import typing as t
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_stablelm": ["StableLMConfig", "START_STABLELM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_stablelm"] = ["StableLM"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_stablelm"] = ["StableLM"]
if t.TYPE_CHECKING:
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING
from .configuration_stablelm import StableLMConfig as StableLMConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_stablelm import StableLM as StableLM
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING
from .configuration_stablelm import StableLMConfig as StableLMConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_stablelm import StableLM as StableLM
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
import openllm
class StableLMConfig(openllm.LLMConfig):
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
It is pre-trained on a diverse collection of English datasets with a sequence
length of 4096 to push beyond the context window limitations of existing open-source language models.
@@ -27,23 +28,14 @@ class StableLMConfig(openllm.LLMConfig):
and [StableLM-base's model card](https://huggingface.co/stabilityai/stablelm-base-alpha-7b)
for more information.
"""
__config__ = {
"name_type": "lowercase",
"url": "https://github.com/Stability-AI/StableLM",
"architecture": "GPTNeoXForCausalLM",
"default_id": "stabilityai/stablelm-tuned-alpha-3b",
"model_ids": [
"stabilityai/stablelm-tuned-alpha-3b",
"stabilityai/stablelm-tuned-alpha-7b",
"stabilityai/stablelm-base-alpha-3b",
"stabilityai/stablelm-base-alpha-7b",
],
}
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 128
top_k: int = 0
top_p: float = 0.9
__config__ = {"name_type": "lowercase", "url": "https://github.com/Stability-AI/StableLM", "architecture": "GPTNeoXForCausalLM", "default_id": "stabilityai/stablelm-tuned-alpha-3b", "model_ids": ["stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b",],}
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 128
top_k: int = 0
top_p: float = 0.9
START_STABLELM_COMMAND_DOCSTRING = """\
Run a LLMServer for StableLM model.

View File

@@ -21,17 +21,28 @@ from ..._prompt import process_prompt
if t.TYPE_CHECKING: import transformers, torch
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
logger = logging.getLogger(__name__)
class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
__openllm_internal__ = True
def llm_post_init(self): self.bettertransformer = True if not torch.cuda.is_available() else False
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if "tuned" in self._model_id and use_default_prompt_template:
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
else: prompt_text = prompt
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)]
__openllm_internal__ = True
def llm_post_init(self):
self.bettertransformer = True if not torch.cuda.is_available() else False
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if "tuned" in self._model_id and use_default_prompt_template:
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
else:
prompt_text = prompt
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():
return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)]

View File

@@ -17,17 +17,23 @@ import typing as t
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {"configuration_starcoder": ["StarCoderConfig", "START_STARCODER_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: _import_structure["modeling_starcoder"] = ["StarCoder"]
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_starcoder"] = ["StarCoder"]
if t.TYPE_CHECKING:
from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING
from .configuration_starcoder import StarCoderConfig as StarCoderConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError: pass
else: from .modeling_starcoder import StarCoder as StarCoder
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING
from .configuration_starcoder import StarCoderConfig as StarCoderConfig
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
from .modeling_starcoder import StarCoder as StarCoder
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
import openllm
class StarCoderConfig(openllm.LLMConfig):
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150),
[a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the
@@ -22,24 +23,17 @@ class StarCoderConfig(openllm.LLMConfig):
Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
"""
__config__ = {
"name_type": "lowercase",
"requires_gpu": True,
"url": "https://github.com/bigcode-project/starcoder",
"architecture": "GPTBigCodeForCausalLM",
"requirements": ["bitsandbytes"],
"workers_per_resource": 0.5,
"default_id": "bigcode/starcoder",
"model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],
}
class GenerationConfig:
temperature: float = 0.2
max_new_tokens: int = 256
min_new_tokens: int = 32
top_k: float = 50
top_p: float = 0.95
pad_token_id: int = 49152
repetition_penalty: float = 1.2
__config__ = {"name_type": "lowercase", "requires_gpu": True, "url": "https://github.com/bigcode-project/starcoder", "architecture": "GPTBigCodeForCausalLM", "requirements": ["bitsandbytes"], "workers_per_resource": 0.5, "default_id": "bigcode/starcoder", "model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],}
class GenerationConfig:
temperature: float = 0.2
max_new_tokens: int = 256
min_new_tokens: int = 32
top_k: float = 50
top_p: float = 0.95
pad_token_id: int = 49152
repetition_penalty: float = 1.2
START_STARCODER_COMMAND_DOCSTRING = """\
Run a LLMServer for StarCoder model.

View File

@@ -18,47 +18,61 @@ import bentoml
import openllm
from ...utils import generate_labels
if t.TYPE_CHECKING:
import torch, transformers
import torch, transformers
else:
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = "<fim-prefix>", "<fim-middle>", "<fim-suffix>", "<fim-pad>", "<|endoftext|>", "<FILL_HERE>"
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
finally: torch.cuda.empty_cache()
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
if fim_mode:
try: prefix, suffix = prompt.split(FIM_INDICATOR)
except Exception as err: raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
else: prompt_text = prompt
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
# default starcoder doesn't include the same value as santacoder EOD
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
# NOTE: support fine-tuning starcoder
result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), do_sample=True, pad_token_id=self.tokenizer.eos_token_id, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
# TODO: We will probably want to return the tokenizer here so that we can manually process this
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[: -len(stop_seq)]
return [{"generated_text": result}]
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
try:
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
finally:
torch.cuda.empty_cache()
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
if fim_mode:
try:
prefix, suffix = prompt.split(FIM_INDICATOR)
except Exception as err:
raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
else:
prompt_text = prompt
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
# default starcoder doesn't include the same value as santacoder EOD
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
# NOTE: support fine-tuning starcoder
result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), do_sample=True, pad_token_id=self.tokenizer.eos_token_id, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
# TODO: We will probably want to return the tokenizer here so that we can manually process this
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
return [{"generated_text": result}]

View File

@@ -11,7 +11,6 @@ import torch
import openllm
import transformers
# Make sure to have at least one GPU to run this script
openllm.utils.configure_logging()
@@ -21,90 +20,54 @@ logger = logging.getLogger(__name__)
# On notebook, make sure to install the following
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
from datasets import load_dataset
from trl import SFTTrainer
DEFAULT_MODEL_ID = "ybelkada/falcon-7b-sharded-bf16"
DATASET_NAME = "timdettmers/openassistant-guanaco"
@dataclasses.dataclass
class TrainingArguments:
per_device_train_batch_size: int = dataclasses.field(default=4)
gradient_accumulation_steps: int = dataclasses.field(default=4)
optim: str = dataclasses.field(default="paged_adamw_32bit")
save_steps: int = dataclasses.field(default=10)
warmup_steps: int = dataclasses.field(default=10)
max_steps: int = dataclasses.field(default=500)
logging_steps: int = dataclasses.field(default=10)
learning_rate: float = dataclasses.field(default=2e-4)
max_grad_norm: float = dataclasses.field(default=0.3)
warmup_ratio: float = dataclasses.field(default=0.03)
fp16: bool = dataclasses.field(default=True)
group_by_length: bool = dataclasses.field(default=True)
lr_scheduler_type: str = dataclasses.field(default="constant")
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
per_device_train_batch_size: int = dataclasses.field(default=4)
gradient_accumulation_steps: int = dataclasses.field(default=4)
optim: str = dataclasses.field(default="paged_adamw_32bit")
save_steps: int = dataclasses.field(default=10)
warmup_steps: int = dataclasses.field(default=10)
max_steps: int = dataclasses.field(default=500)
logging_steps: int = dataclasses.field(default=10)
learning_rate: float = dataclasses.field(default=2e-4)
max_grad_norm: float = dataclasses.field(default=0.3)
warmup_ratio: float = dataclasses.field(default=0.03)
fp16: bool = dataclasses.field(default=True)
group_by_length: bool = dataclasses.field(default=True)
lr_scheduler_type: str = dataclasses.field(default="constant")
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
@dataclasses.dataclass
class ModelArguments:
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
max_sequence_length: int = dataclasses.field(default=512)
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
max_sequence_length: int = dataclasses.field(default=512)
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
)
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
model, tokenizer = openllm.AutoLLM.for_model(
"falcon",
model_id=model_args.model_id,
quantize="int4",
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
ensure_available=True,
).prepare_for_training(
adapter_type="lora",
lora_alpha=16,
lora_dropout=0.1,
r=16,
bias="none",
target_modules=[
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
],
)
model, tokenizer = openllm.AutoLLM.for_model("falcon", model_id=model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, ensure_available=True,).prepare_for_training(adapter_type="lora", lora_alpha=16, lora_dropout=0.1, r=16, bias="none", target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",],)
model.config.use_cache = False
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset(DATASET_NAME, split="train")
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=model_args.max_sequence_length,
tokenizer=tokenizer,
args=dataclasses.replace(
transformers.TrainingArguments(training_args.output_dir),
**dataclasses.asdict(training_args),
),
)
trainer = SFTTrainer(model=model, train_dataset=dataset, dataset_text_field="text", max_seq_length=model_args.max_sequence_length, tokenizer=tokenizer, args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args),),)
# upcast layernorm in float32 for more stable training
for name, module in trainer.model.named_modules():
if "norm" in name:
module = module.to(torch.float32)
if "norm" in name:
module = module.to(torch.float32)
trainer.train()

View File

@@ -5,7 +5,6 @@ import typing as t
import openllm
openllm.utils.configure_logging()
logger = logging.getLogger(__name__)
@@ -15,45 +14,42 @@ MAX_NEW_TOKENS = 384
Q = "Answer the following question, step by step:\n{q}\nA:"
question = "What is the meaning of life?"
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("question", default=question)
parser = argparse.ArgumentParser()
parser.add_argument("question", default=question)
if openllm.utils.in_notebook():
args = parser.parse_args(args=[question])
else:
args = parser.parse_args()
if openllm.utils.in_notebook():
args = parser.parse_args(args=[question])
else:
args = parser.parse_args()
model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True)
prompt = Q.format(q=args.question)
model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True)
prompt = Q.format(q=args.question)
logger.info("-" * 50, "Running with 'generate()'", "-" * 50)
res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
logger.info("-" * 50, "Running with 'generate()'", "-" * 50)
res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50)
res = model.generate(prompt, num_return_sequences=3)
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50)
res = model.generate(prompt, num_return_sequences=3)
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50)
r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True)
res = r.generate.run(prompt)
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50)
r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True)
res = r.generate.run(prompt)
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50)
res = r(prompt)
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
return 0
logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50)
res = r(prompt)
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
return 0
def _mp_fn(index: t.Any): # noqa # type: ignore
# For xla_spawn (TPUs)
main()
# For xla_spawn (TPUs)
main()
if openllm.utils.in_notebook():
main()
main()
else:
raise SystemExit(main())
raise SystemExit(main())

View File

@@ -11,7 +11,7 @@ import torch
import transformers
if t.TYPE_CHECKING:
import peft
import peft
# Make sure to have at least one GPU to run this script
@@ -29,19 +29,17 @@ from itertools import chain
from functools import partial
from random import randrange
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
def find_all_linear_names(model):
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
# Change this to the local converted path if you don't have access to the meta-llama model
DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
@@ -49,202 +47,149 @@ DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135"
DATASET_NAME = "databricks/databricks-dolly-15k"
def format_dolly(sample):
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
response = f"### Answer\n{sample['response']}"
# join all the parts together
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
return prompt
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
response = f"### Answer\n{sample['response']}"
# join all the parts together
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
return prompt
# template dataset to add prompt to each sample
def template_dataset(sample, tokenizer):
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
return sample
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
return sample
# empty list to save remainder from batches to use in next batch
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
def chunk(sample, chunk_length=2048):
# define global remainder variable to save remainder from batches to use in next batch
global remainder
# Concatenate all texts and add remainder from previous batch
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
# get total number of tokens for batch
batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
# define global remainder variable to save remainder from batches to use in next batch
global remainder
# Concatenate all texts and add remainder from previous batch
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
# get total number of tokens for batch
batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
# get max number of chunks for batch
if batch_total_length >= chunk_length:
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
# Split by chunks of max_len.
result = {
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
for k, t in concatenated_examples.items()
}
# add remainder to global variable for next batch
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
# prepare labels
result["labels"] = result["input_ids"].copy()
return result
# get max number of chunks for batch
if batch_total_length >= chunk_length:
batch_chunk_length = (batch_total_length//chunk_length) * chunk_length
# Split by chunks of max_len.
result = {k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] for k, t in concatenated_examples.items()}
# add remainder to global variable for next batch
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
# prepare labels
result["labels"] = result["input_ids"].copy()
return result
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
# Load dataset from the hub
dataset = load_dataset(dataset_name, split="train")
# Load dataset from the hub
dataset = load_dataset(dataset_name, split="train")
print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])
print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])
# apply prompt template per sample
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
# print random sample
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
# apply prompt template per sample
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
# print random sample
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
# tokenize and chunk dataset
lm_dataset = dataset.map(
lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)
).map(
partial(chunk, chunk_length=2048),
batched=True,
)
# Print total number of samples
print(f"Total number of samples: {len(lm_dataset)}")
return lm_dataset
# tokenize and chunk dataset
lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True,)
# Print total number of samples
print(f"Total number of samples: {len(lm_dataset)}")
return lm_dataset
@openllm.utils.requires_dependencies("peft", extra="fine-tune")
def prepare_for_int4_training(
model_id: str,
model_version: str | None = None,
gradient_checkpointing: bool = True,
bf16: bool = True,
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
from peft.tuners.lora import LoraLayer
def prepare_for_int4_training(model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True,) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
from peft.tuners.lora import LoraLayer
llm = openllm.AutoLLM.for_model(
"llama",
model_id=model_id,
model_version=model_version,
ensure_available=True,
quantize="int4",
bnb_4bit_compute_dtype=torch.bfloat16,
use_cache=not gradient_checkpointing,
device_map="auto",
)
print("Model summary:", llm.model)
llm = openllm.AutoLLM.for_model("llama", model_id=model_id, model_version=model_version, ensure_available=True, quantize="int4", bnb_4bit_compute_dtype=torch.bfloat16, use_cache=not gradient_checkpointing, device_map="auto",)
print("Model summary:", llm.model)
# get lora target modules
modules = find_all_linear_names(llm.model)
print(f"Found {len(modules)} modules to quantize: {modules}")
# get lora target modules
modules = find_all_linear_names(llm.model)
print(f"Found {len(modules)} modules to quantize: {modules}")
model, tokenizer = llm.prepare_for_training(
adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
)
# pre-process the model by upcasting the layer norms in float 32 for
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if bf16:
module = module.to(torch.bfloat16)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
if bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
return model, tokenizer
model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
# pre-process the model by upcasting the layer norms in float 32 for
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if bf16:
module = module.to(torch.bfloat16)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
if bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
return model, tokenizer
@dataclasses.dataclass
class TrainingArguments:
per_device_train_batch_size: int = dataclasses.field(default=1)
gradient_checkpointing: bool = dataclasses.field(default=True)
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
learning_rate: float = dataclasses.field(default=5e-5)
num_train_epochs: int = dataclasses.field(default=3)
logging_steps: int = dataclasses.field(default=1)
report_to: str = dataclasses.field(default="none")
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
save_strategy: str = dataclasses.field(default="no")
per_device_train_batch_size: int = dataclasses.field(default=1)
gradient_checkpointing: bool = dataclasses.field(default=True)
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
learning_rate: float = dataclasses.field(default=5e-5)
num_train_epochs: int = dataclasses.field(default=3)
logging_steps: int = dataclasses.field(default=1)
report_to: str = dataclasses.field(default="none")
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
save_strategy: str = dataclasses.field(default="no")
@dataclasses.dataclass
class ModelArguments:
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
seed: int = dataclasses.field(default=42)
merge_weights: bool = dataclasses.field(default=False)
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
seed: int = dataclasses.field(default=42)
merge_weights: bool = dataclasses.field(default=False)
if openllm.utils.in_notebook():
model_args, training_rags = ModelArguments(), TrainingArguments()
model_args, training_rags = ModelArguments(), TrainingArguments()
else:
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
)
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
# import the model first hand
openllm.import_model("llama", model_id=model_args.model_id, model_version=model_args.model_version)
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
import peft
import peft
transformers.set_seed(model_args.seed)
transformers.set_seed(model_args.seed)
model, tokenizer = prepare_for_int4_training(
model_args.model_id,
gradient_checkpointing=training_args.gradient_checkpointing,
bf16=training_args.bf16,
)
datasets = prepare_datasets(tokenizer)
model, tokenizer = prepare_for_int4_training(model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16,)
datasets = prepare_datasets(tokenizer)
trainer = transformers.Trainer(
model=model,
args=dataclasses.replace(
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
),
train_dataset=datasets,
data_collator=transformers.default_data_collator,
)
trainer = transformers.Trainer(model=model, args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)), train_dataset=datasets, data_collator=transformers.default_data_collator,)
trainer.train()
trainer.train()
if model_args.merge_weights:
# note that this will requires larger GPU as we will load the whole model into memory
if model_args.merge_weights:
# note that this will requires larger GPU as we will load the whole model into memory
# merge adapter weights with base model and save
# save int4 model
trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)
# merge adapter weights with base model and save
# save int4 model
trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)
# gc mem
del model, trainer
torch.cuda.empty_cache()
model = peft.AutoPeftModelForCausalLM.from_pretrained(
training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
)
# merge lora with base weights and save
model = model.merge_and_unload()
model.save_pretrained(
os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB"
)
else:
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
# gc mem
del model, trainer
torch.cuda.empty_cache()
model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16)
# merge lora with base weights and save
model = model.merge_and_unload()
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB")
else:
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
train_loop(model_args, training_args)

View File

@@ -20,71 +20,38 @@ logger = logging.getLogger(__name__)
from datasets import load_dataset
if t.TYPE_CHECKING:
from peft import PeftModel
from peft import PeftModel
DEFAULT_MODEL_ID = "facebook/opt-6.7b"
def load_trainer(
model: PeftModel,
tokenizer: transformers.GPT2TokenizerFast,
dataset_dict: t.Any,
training_args: TrainingArguments,
):
return transformers.Trainer(
model=model,
train_dataset=dataset_dict["train"],
args=dataclasses.replace(
transformers.TrainingArguments(training_args.output_dir),
**dataclasses.asdict(training_args),
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments,):
return transformers.Trainer(model=model, train_dataset=dataset_dict["train"], args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args),), data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),)
@dataclasses.dataclass
class TrainingArguments:
per_device_train_batch_size: int = dataclasses.field(default=4)
gradient_accumulation_steps: int = dataclasses.field(default=4)
warmup_steps: int = dataclasses.field(default=10)
max_steps: int = dataclasses.field(default=50)
learning_rate: float = dataclasses.field(default=3e-4)
fp16: bool = dataclasses.field(default=True)
logging_steps: int = dataclasses.field(default=1)
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
per_device_train_batch_size: int = dataclasses.field(default=4)
gradient_accumulation_steps: int = dataclasses.field(default=4)
warmup_steps: int = dataclasses.field(default=10)
max_steps: int = dataclasses.field(default=50)
learning_rate: float = dataclasses.field(default=3e-4)
fp16: bool = dataclasses.field(default=True)
logging_steps: int = dataclasses.field(default=1)
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
@dataclasses.dataclass
class ModelArguments:
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
)
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
model, tokenizer = openllm.AutoLLM.for_model(
"opt",
model_id=model_args.model_id,
quantize="int8",
ensure_available=True,
).prepare_for_training(
adapter_type="lora",
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
)
model, tokenizer = openllm.AutoLLM.for_model("opt", model_id=model_args.model_id, quantize="int8", ensure_available=True,).prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none",)
# ft on english_quotes
data = load_dataset("Abirate/english_quotes")

View File

@@ -49,104 +49,80 @@ from ..utils import LazyLoader
from ..utils import LazyModule
if t.TYPE_CHECKING:
import bentoml
import transformers
import bentoml
import transformers
from .._llm import M
from .._llm import T
from .._llm import M
from .._llm import T
else:
transformers = LazyLoader("transformers", globals(), "transformers")
transformers = LazyLoader("transformers", globals(), "transformers")
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model:
if llm.runtime == "transformers":
return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
elif llm.runtime == "ggml":
return openllm.ggml.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
if llm.runtime == "transformers":
return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
elif llm.runtime == "ggml":
return openllm.ggml.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
if llm.runtime == "transformers":
return openllm.transformers.get(llm, auto_import=auto_import)
elif llm.runtime == "ggml":
return openllm.ggml.get(llm, auto_import=auto_import)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
if llm.runtime == "transformers":
return openllm.transformers.get(llm, auto_import=auto_import)
elif llm.runtime == "ggml":
return openllm.ggml.get(llm, auto_import=auto_import)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, **attrs: t.Any) -> None:
if llm.runtime == "transformers":
return openllm.transformers.save_pretrained(llm, save_directory, **attrs)
elif llm.runtime == "ggml":
return openllm.ggml.save_pretrained(llm, save_directory, **attrs)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
if llm.runtime == "transformers":
return openllm.transformers.save_pretrained(llm, save_directory, **attrs)
elif llm.runtime == "ggml":
return openllm.ggml.save_pretrained(llm, save_directory, **attrs)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
if llm.runtime == "transformers":
return openllm.transformers.load_model(llm, *decls, **attrs)
elif llm.runtime == "ggml":
return openllm.ggml.load_model(llm, *decls, **attrs)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
if llm.runtime == "transformers":
return openllm.transformers.load_model(llm, *decls, **attrs)
elif llm.runtime == "ggml":
return openllm.ggml.load_model(llm, *decls, **attrs)
else:
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
"""Load the tokenizer from BentoML store.
"""Load the tokenizer from BentoML store.
By default, it will try to find the bentomodel whether it is in store..
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
from .transformers import infer_tokenizers_class_for_llm
By default, it will try to find the bentomodel whether it is in store..
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
from .transformers import infer_tokenizers_class_for_llm
bentomodel_fs = llm._bentomodel._fs
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
try:
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
except KeyError:
# This could happen if users implement their own import_model
raise OpenLLMException(
"Model does not have tokenizer. Make sure to save \
bentomodel_fs = llm._bentomodel._fs
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
try:
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
except KeyError:
# This could happen if users implement their own import_model
raise OpenLLMException("Model does not have tokenizer. Make sure to save \
the tokenizer within the model via 'custom_objects'.\
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
) from None
else:
tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(
bentomodel_fs.getsyspath("/"),
trust_remote_code=llm.__llm_trust_remote_code__,
**tokenizer_attrs,
)
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))") from None
else:
tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs,)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
return tokenizer
_extras = {
"get": get,
"import_model": import_model,
"save_pretrained": save_pretrained,
"load_model": load_model,
"load_tokenizer": load_tokenizer,
}
_extras = {"get": get, "import_model": import_model, "save_pretrained": save_pretrained, "load_model": load_model, "load_tokenizer": load_tokenizer,}
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": []}
if t.TYPE_CHECKING:
from . import ggml as ggml
from . import transformers as transformers
from . import ggml as ggml
from . import transformers as transformers
else:
import sys
import sys
sys.modules[__name__] = LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
extra_objects=_extras,
)
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras,)

View File

@@ -17,21 +17,8 @@ from __future__ import annotations
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
# NOTE: vllm will use PyTorch implementation of transformers for serialisation
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
}
# this logic below is synonymous to handling `from_pretrained` attrs.
HUB_ATTRS = [
"cache_dir",
"code_revision",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"subfolder",
"use_auth_token",
]
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token",]

View File

@@ -23,54 +23,42 @@ import bentoml
from ..exceptions import OpenLLMException
if t.TYPE_CHECKING:
import openllm
import openllm
from .._llm import M
from .._llm import M
_conversion_strategy = {"pt": "ggml"}
def import_model(
llm: openllm.LLM[t.Any, t.Any],
*decls: t.Any,
trust_remote_code: bool = True,
**attrs: t.Any,
) -> bentoml.Model:
raise NotImplementedError("Currently work in progress.")
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
raise NotImplementedError("Currently work in progress.")
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
try:
model = bentoml.models.get(llm.tag)
if model.info.module not in ("openllm.serialisation.ggml", __name__):
raise bentoml.exceptions.NotFound(
f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'."
)
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
raise OpenLLMException(
f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}."
)
return model
except bentoml.exceptions.NotFound:
if auto_import:
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
try:
model = bentoml.models.get(llm.tag)
if model.info.module not in ("openllm.serialisation.ggml", __name__):
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
return model
except bentoml.exceptions.NotFound:
if auto_import:
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
"""Load the model from BentoML store.
By default, it will try to find check the model in the local store.
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
raise NotImplementedError("Currently work in progress.")
"""Load the model from BentoML store.
By default, it will try to find check the model in the local store.
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
raise NotImplementedError("Currently work in progress.")
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
raise NotImplementedError("Currently work in progress.")
raise NotImplementedError("Currently work in progress.")

View File

@@ -38,237 +38,183 @@ from ..utils import lenient_issubclass
from ..utils import normalize_attrs_to_model_tokenizer_pair
if t.TYPE_CHECKING:
import auto_gptq as autogptq
import torch
import torch.cuda
import vllm
import auto_gptq as autogptq
import torch
import torch.cuda
import vllm
import openllm
import transformers as _transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass
import openllm
import transformers as _transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from .._llm import M
from .._llm import T
from .._types import DictStrAny
from .._llm import M
from .._llm import T
from .._types import DictStrAny
else:
vllm = LazyLoader("vllm", globals(), "vllm")
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
_transformers = LazyLoader("_transformers", globals(), "transformers")
torch = LazyLoader("torch", globals(), "torch")
torch.cuda = LazyLoader("torch.cuda", globals(), "torch.cuda")
vllm = LazyLoader("vllm", globals(), "vllm")
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
_transformers = LazyLoader("_transformers", globals(), "transformers")
torch = LazyLoader("torch", globals(), "torch")
torch.cuda = LazyLoader("torch.cuda", globals(), "torch.cuda")
_object_setattr = object.__setattr__
def process_transformers_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[_transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]:
"""Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs."""
config = attrs.pop("config", None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
if not isinstance(config, _transformers.PretrainedConfig):
copied_attrs = copy.deepcopy(attrs)
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
return config, hub_attrs, attrs
"""Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs."""
config = attrs.pop("config", None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
if not isinstance(config, _transformers.PretrainedConfig):
copied_attrs = copy.deepcopy(attrs)
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
return config, hub_attrs, attrs
def infer_tokenizers_class_for_llm(__llm: openllm.LLM[t.Any, T]) -> T:
tokenizer_class = __llm.config["tokenizer_class"]
if tokenizer_class is None: tokenizer_class = "AutoTokenizer"
__cls = getattr(_transformers, tokenizer_class)
if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.")
return __cls
tokenizer_class = __llm.config["tokenizer_class"]
if tokenizer_class is None: tokenizer_class = "AutoTokenizer"
__cls = getattr(_transformers, tokenizer_class)
if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.")
return __cls
def infer_autoclass_from_llm_config(llm: openllm.LLM[M, T], config: _transformers.PretrainedConfig) -> _BaseAutoModelClass:
if llm.config["trust_remote_code"]:
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
# in case this model doesn't use the correct auto class for model type, for example like chatglm
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
if autoclass not in config.auto_map: autoclass = "AutoModel"
return getattr(_transformers, autoclass)
else:
if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.")
return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
if llm.config["trust_remote_code"]:
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
# in case this model doesn't use the correct auto class for model type, for example like chatglm
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
if autoclass not in config.auto_map: autoclass = "AutoModel"
return getattr(_transformers, autoclass)
else:
if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.")
return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model:
"""Auto detect model type from given model_id and import it to bentoml's model store.
"""Auto detect model type from given model_id and import it to bentoml's model store.
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
returning all of the unused kwargs.
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
returning all of the unused kwargs.
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
Refer to Transformers documentation for more information about kwargs.
Refer to Transformers documentation for more information about kwargs.
Args:
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
"""
config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs)
_, tokenizer_attrs = llm.llm_parameters
quantize_method = llm._quantize_method
safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
# Disable safe serialization with vLLM
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False}
signatures: DictStrAny = {}
if quantize_method == "gptq":
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
model = autogptq.AutoGPTQForCausalLM.from_quantized(
llm.model_id,
*decls,
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
trust_remote_code=trust_remote_code,
use_safetensors=safe_serialisation,
**hub_attrs,
**attrs,
)
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
signatures["generate"] = {"batchable": False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
# since saving int4 is not yet supported
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
llm.model_id,
*decls,
config=config,
trust_remote_code=trust_remote_code,
use_safetensors=safe_serialisation,
**hub_attrs,
**attrs,
)
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.framework})
Args:
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
"""
config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs)
_, tokenizer_attrs = llm.llm_parameters
quantize_method = llm._quantize_method
safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
# Disable safe serialization with vLLM
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False}
signatures: DictStrAny = {}
if quantize_method == "gptq":
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,)
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
signatures["generate"] = {"batchable": False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
# since saving int4 is not yet supported
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,)
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.framework})
_tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(
llm.model_id,
trust_remote_code=trust_remote_code,
**hub_attrs,
**tokenizer_attrs,
)
if _tokenizer.pad_token is None: _tokenizer.pad_token = _tokenizer.eos_token
_tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs,)
if _tokenizer.pad_token is None: _tokenizer.pad_token = _tokenizer.eos_token
# NOTE: quick hack to set the loaded into llm object to use with save_pretrained
# to avoid recursive call when the model is not yet available in local store
_object_setattr(llm, "__llm_model__", model)
_object_setattr(llm, "__llm_tokenizer__", _tokenizer)
create_kwargs: DictStrAny= dict(
module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), labels=generate_labels(llm), metadata=metadata,
signatures=signatures if signatures else make_default_signatures(model), options=ModelOptions(),
external_modules=[importlib.import_module(model.__module__), importlib.import_module(_tokenizer.__module__)] if trust_remote_code else None,
)
if "use_tempfs" in inspect.signature(bentoml.models.create).parameters: create_kwargs["use_tempfs"] = False
# NOTE: quick hack to set the loaded into llm object to use with save_pretrained
# to avoid recursive call when the model is not yet available in local store
_object_setattr(llm, "__llm_model__", model)
_object_setattr(llm, "__llm_tokenizer__", _tokenizer)
create_kwargs: DictStrAny = dict(
module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), labels=generate_labels(llm), metadata=metadata, signatures=signatures if signatures else make_default_signatures(model), options=ModelOptions(), external_modules=[importlib.import_module(model.__module__),
importlib.import_module(_tokenizer.__module__)] if trust_remote_code else None,
)
if "use_tempfs" in inspect.signature(bentoml.models.create).parameters: create_kwargs["use_tempfs"] = False
try:
with bentoml.models.create(llm.tag, **create_kwargs) as bentomodel:
save_pretrained(llm, bentomodel.path, safe_serialization=safe_serialisation)
return bentomodel
finally:
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model
# available locally.
if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
try:
with bentoml.models.create(llm.tag, **create_kwargs) as bentomodel:
save_pretrained(llm, bentomodel.path, safe_serialization=safe_serialisation)
return bentomodel
finally:
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model
# available locally.
if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
try:
model = bentoml.models.get(llm.tag)
# compat with bentoml.transformers.get
if model.info.module not in ("openllm.serialisation.transformers", __name__):
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
return model
except bentoml.exceptions.NotFound as err:
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise err from None
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
try:
model = bentoml.models.get(llm.tag)
# compat with bentoml.transformers.get
if model.info.module not in ("openllm.serialisation.transformers", __name__):
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
return model
except bentoml.exceptions.NotFound as err:
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise err from None
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
"""Load the model from BentoML store.
"""Load the model from BentoML store.
By default, it will try to find check the model in the local store.
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
config, hub_attrs, attrs = process_transformers_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
safe_serialization = first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors")
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
return autogptq.AutoGPTQForCausalLM.from_quantized(
llm._bentomodel.path,
*decls,
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
trust_remote_code=llm.__llm_trust_remote_code__,
use_safetensors=safe_serialization,
**hub_attrs,
**attrs,
)
By default, it will try to find check the model in the local store.
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
config, hub_attrs, attrs = process_transformers_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
safe_serialization = first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors")
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=llm.__llm_trust_remote_code__, use_safetensors=safe_serialization, **hub_attrs, **attrs,)
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
llm._bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.__llm_trust_remote_code__,
**hub_attrs,
**attrs,
)
# NOTE: we only cast and load the model if it is not already quantized and setup correctly
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_quantized", False)
if torch.cuda.is_available() and device_count() == 1 and not loaded_in_kbit:
try: model = model.to("cuda")
except torch.cuda.OutOfMemoryError as err: raise RuntimeError(f"Failed to convert {llm.config['model_name']} with model_id '{llm.model_id}' to CUDA.\nNote: You can try out '--quantize int8 | int4' for dynamic quantization.") from err
# BetterTransformer is currently only supported on PyTorch.
if llm.bettertransformer and isinstance(model, _transformers.PreTrainedModel): model = model.to_bettertransformer()
return t.cast("M", model)
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, **hub_attrs, **attrs,)
# NOTE: we only cast and load the model if it is not already quantized and setup correctly
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_quantized", False)
if torch.cuda.is_available() and device_count() == 1 and not loaded_in_kbit:
try:
model = model.to("cuda")
except torch.cuda.OutOfMemoryError as err:
raise RuntimeError(f"Failed to convert {llm.config['model_name']} with model_id '{llm.model_id}' to CUDA.\nNote: You can try out '--quantize int8 | int4' for dynamic quantization.") from err
# BetterTransformer is currently only supported on PyTorch.
if llm.bettertransformer and isinstance(model, _transformers.PreTrainedModel): model = model.to_bettertransformer()
return t.cast("M", model)
def save_pretrained(
llm: openllm.LLM[M, T],
save_directory: str,
is_main_process: bool = True,
state_dict: DictStrAny | None = None,
save_function: t.Callable[..., None] | None = None,
push_to_hub: bool = False,
max_shard_size: int | str = "10GB",
safe_serialization: bool = False,
variant: str | None = None,
**attrs: t.Any,
) -> None:
"""Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``."""
save_function = first_not_none(save_function, default=torch.save)
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
# NOTE: disable safetensors for vllm
if llm.__llm_implementation__ == "vllm": safe_serialization = False
if llm._quantize_method == "gptq":
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if not lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
elif LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model): raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
elif isinstance(llm.model, _transformers.Pipeline): llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
else:
# We can safely cast here since it will be the PreTrainedModel protocol.
t.cast("_transformers.PreTrainedModel", llm.model).save_pretrained(
save_directory,
is_main_process=is_main_process,
state_dict=state_dict,
save_function=save_function,
push_to_hub=push_to_hub,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
variant=variant,
**model_save_attrs
)
llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs)
def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, is_main_process: bool = True, state_dict: DictStrAny | None = None, save_function: t.Callable[..., None] | None = None, push_to_hub: bool = False, max_shard_size: int | str = "10GB", safe_serialization: bool = False, variant: str | None = None, **attrs: t.Any,) -> None:
"""Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``."""
save_function = first_not_none(save_function, default=torch.save)
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
# NOTE: disable safetensors for vllm
if llm.__llm_implementation__ == "vllm": safe_serialization = False
if llm._quantize_method == "gptq":
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if not lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
elif LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model):
raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
elif isinstance(llm.model, _transformers.Pipeline):
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
else:
# We can safely cast here since it will be the PreTrainedModel protocol.
t.cast("_transformers.PreTrainedModel", llm.model).save_pretrained(save_directory, is_main_process=is_main_process, state_dict=state_dict, save_function=save_function, push_to_hub=push_to_hub, max_shard_size=max_shard_size, safe_serialization=safe_serialization, variant=variant, **model_save_attrs)
llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs)

View File

@@ -26,77 +26,50 @@ import openllm
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from ._types import LiteralRuntime
from ._types import LiteralRuntime
@contextlib.contextmanager
def build_bento(
model: str,
model_id: str | None = None,
quantize: t.Literal["int4", "int8", "gptq"] | None = None,
runtime: t.Literal["ggml", "transformers"] = "transformers",
cleanup: bool = False,
) -> t.Iterator[bentoml.Bento]:
logger.info("Building BentoML for %s", model)
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
yield bento
def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False,) -> t.Iterator[bentoml.Bento]:
logger.info("Building BentoML for %s", model)
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
yield bento
if cleanup:
logger.info("Deleting %s", bento.tag)
bentoml.bentos.delete(bento.tag)
@contextlib.contextmanager
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any,) -> t.Iterator[str]:
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
else: bento_tag = bentoml.Tag.from_taglike(bento)
if image_tag is None: image_tag = str(bento_tag)
executable = shutil.which("docker")
if not executable:
raise RuntimeError("docker executable not found")
try:
logger.info("Building container for %s", bento_tag)
bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs,)
yield image_tag
finally:
if cleanup:
logger.info("Deleting %s", bento.tag)
bentoml.bentos.delete(bento.tag)
logger.info("Deleting container %s", image_tag)
subprocess.check_output([executable, "rmi", "-f", image_tag])
@contextlib.contextmanager
def build_container(
bento: bentoml.Bento | str | bentoml.Tag,
image_tag: str | None = None,
cleanup: bool = False,
**attrs: t.Any,
) -> t.Iterator[str]:
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
else: bento_tag = bentoml.Tag.from_taglike(bento)
if image_tag is None: image_tag = str(bento_tag)
def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True,) -> t.Iterator[str]:
if clean_context is None:
clean_context = contextlib.ExitStack()
cleanup = True
executable = shutil.which("docker")
if not executable:
raise RuntimeError("docker executable not found")
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
try:
logger.info("Building container for %s", bento_tag)
bentoml.container.build(
bento_tag,
backend="docker",
image_tag=(image_tag,),
progress="plain",
**attrs,
)
yield image_tag
finally:
if cleanup:
logger.info("Deleting container %s", image_tag)
subprocess.check_output([executable, "rmi", "-f", image_tag])
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
else: bento = bentoml.get(bento_tag)
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
@contextlib.contextmanager
def prepare(
model: str,
model_id: str | None = None,
implementation: LiteralRuntime = "pt",
deployment_mode: t.Literal["container", "local"] = "local",
clean_context: contextlib.ExitStack | None = None,
cleanup: bool = True,
) -> t.Iterator[str]:
if clean_context is None:
clean_context = contextlib.ExitStack()
cleanup = True
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
else: bento = bentoml.get(bento_tag)
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
yield container_name
if cleanup: clean_context.close()
if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
yield container_name
if cleanup: clean_context.close()

View File

@@ -50,74 +50,77 @@ from .lazy import LazyModule
logger = logging.getLogger(__name__)
try:
from typing import GenericAlias as _TypingGenericAlias # type: ignore
from typing import GenericAlias as _TypingGenericAlias # type: ignore
except ImportError:
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
_TypingGenericAlias = () # type: ignore
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
_TypingGenericAlias = () # type: ignore
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
else:
# _GenericAlias is the actual GenericAlias implementation
_WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore
# _GenericAlias is the actual GenericAlias implementation
_WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload as _overload
from typing import overload as _overload
else:
from typing_extensions import overload as _overload
from typing_extensions import overload as _overload
if t.TYPE_CHECKING:
import openllm
import openllm
from .._types import AnyCallable
from .._types import DictStrAny
from .._types import LiteralRuntime
from .._types import AnyCallable
from .._types import DictStrAny
from .._types import LiteralRuntime
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
def set_debug_mode(enabled: bool, level: int = 1) -> None:
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
os.environ[DEBUG_ENV_VAR] = str(enabled)
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
os.environ[DEBUG_ENV_VAR] = str(enabled)
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError:
if isinstance(cls, _WithArgsTypes): return False
raise
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError:
if isinstance(cls, _WithArgsTypes): return False
raise
def available_devices() -> tuple[str, ...]:
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
from .._strategies import NvidiaGpuResource
return tuple(NvidiaGpuResource.from_system())
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
from .._strategies import NvidiaGpuResource
return tuple(NvidiaGpuResource.from_system())
@functools.lru_cache(maxsize=128)
def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str:
"""Generate a hash from given file's modification time.
"""Generate a hash from given file's modification time.
Args:
f: The file to generate the hash from.
algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)
Args:
f: The file to generate the hash from.
algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)
Returns:
The generated hash.
"""
return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()
Returns:
The generated hash.
"""
return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()
@functools.lru_cache(maxsize=1)
def device_count() -> int: return len(available_devices())
def device_count() -> int:
return len(available_devices())
# equivocal setattr to save one lookup per assignment
_object_setattr = object.__setattr__
def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
"""This makes sure that we don't overwrite any existing attributes on the object."""
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
if not hasattr(obj, name): _setattr(name, value)
"""This makes sure that we don't overwrite any existing attributes on the object."""
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
if not hasattr(obj, name): _setattr(name, value)
def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str: return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str:
return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
# Special debug flag controled via OPENLLMDEVDEBUG
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.getenv(DEV_DEBUG_VAR)))
@@ -125,210 +128,201 @@ DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.gete
MYPY = False
SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
def get_debug_mode() -> bool: return DEBUG or _get_debug_mode()
def get_quiet_mode() -> bool: return not DEBUG and _get_quiet_mode()
def get_debug_mode() -> bool:
return DEBUG or _get_debug_mode()
def get_quiet_mode() -> bool:
return not DEBUG and _get_quiet_mode()
class ExceptionFilter(logging.Filter):
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
"""A filter of all exception."""
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
super(ExceptionFilter, self).__init__(**kwargs)
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info:
etype, _, _ = record.exc_info
if etype is not None:
for exc in self.EXCLUDE_EXCEPTIONS:
if issubclass(etype, exc): return False
return True
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
"""A filter of all exception."""
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
super(ExceptionFilter, self).__init__(**kwargs)
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info:
etype, _, _ = record.exc_info
if etype is not None:
for exc in self.EXCLUDE_EXCEPTIONS:
if issubclass(etype, exc): return False
return True
class InfoFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING
def filter(self, record: logging.LogRecord) -> bool:
return logging.INFO <= record.levelno < logging.WARNING
_LOGGING_CONFIG: DictStrAny = {
"version": 1,
"disable_existing_loggers": True,
"filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}},
"handlers": {
"bentomlhandler": {
"class": "logging.StreamHandler",
"filters": ["excfilter", "infofilter"],
"stream": "ext://sys.stdout",
},
"defaulthandler": {
"class": "logging.StreamHandler",
"level": logging.WARNING,
},
},
"loggers": {
"bentoml": {
"handlers": ["bentomlhandler", "defaulthandler"],
"level": logging.INFO,
"propagate": False,
},
"openllm": {
"handlers": ["bentomlhandler", "defaulthandler"],
"level": logging.INFO,
"propagate": False,
},
},
"root": {"level": logging.WARNING},
"version": 1, "disable_existing_loggers": True, "filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}}, "handlers": {
"bentomlhandler": {"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout",}, "defaulthandler": {"class": "logging.StreamHandler", "level": logging.WARNING,},
}, "loggers": {"bentoml": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,}, "openllm": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,},}, "root": {"level": logging.WARNING},
}
def configure_logging() -> None:
"""Configure logging for OpenLLM.
"""Configure logging for OpenLLM.
Behaves similar to how BentoML loggers are being configured.
"""
if get_quiet_mode():
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
_LOGGING_CONFIG["root"]["level"] = logging.ERROR
elif get_debug_mode() or DEBUG:
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG
_LOGGING_CONFIG["root"]["level"] = logging.DEBUG
else:
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO
_LOGGING_CONFIG["root"]["level"] = logging.INFO
Behaves similar to how BentoML loggers are being configured.
"""
if get_quiet_mode():
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
_LOGGING_CONFIG["root"]["level"] = logging.ERROR
elif get_debug_mode() or DEBUG:
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG
_LOGGING_CONFIG["root"]["level"] = logging.DEBUG
else:
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO
_LOGGING_CONFIG["root"]["level"] = logging.INFO
logging.config.dictConfig(_LOGGING_CONFIG)
logging.config.dictConfig(_LOGGING_CONFIG)
@functools.lru_cache(maxsize=1)
def in_notebook() -> bool:
try:
from IPython.core.getipython import get_ipython
if "IPKernelApp" not in get_ipython().config: return False
except ImportError: return False
except AttributeError: return False
return True
try:
from IPython.core.getipython import get_ipython
if "IPKernelApp" not in get_ipython().config: return False
except ImportError:
return False
except AttributeError:
return False
return True
_dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup")
class suppress(contextlib.suppress, contextlib.ContextDecorator):
"""A version of contextlib.suppress with decorator support.
"""A version of contextlib.suppress with decorator support.
>>> @suppress(KeyError)
... def key_error():
... {}['']
>>> key_error()
"""
>>> @suppress(KeyError)
... def key_error():
... {}['']
>>> key_error()
"""
def compose(*funcs: AnyCallable) -> AnyCallable:
"""Compose any number of unary functions into a single unary function.
"""Compose any number of unary functions into a single unary function.
>>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
>>> strip_and_dedent(compose.__doc__) == expected
True
>>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
>>> strip_and_dedent(compose.__doc__) == expected
True
Compose also allows the innermost function to take arbitrary arguments.
Compose also allows the innermost function to take arbitrary arguments.
>>> round_three = lambda x: round(x, ndigits=3)
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
>>> round_three = lambda x: round(x, ndigits=3)
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
```python
@apply(reversed)
def get_numbers(start):
"doc for get_numbers"
return range(start, start+3)
list(get_numbers(4))
# [6, 5, 4]
```
```python
get_numbers.__doc__
# 'doc for get_numbers'
```
"""
return lambda func: functools.wraps(func)(compose(transform, func))
```python
@apply(reversed)
def get_numbers(start):
"doc for get_numbers"
return range(start, start+3)
list(get_numbers(4))
# [6, 5, 4]
```
```python
get_numbers.__doc__
# 'doc for get_numbers'
```
"""
return lambda func: functools.wraps(func)(compose(transform, func))
@apply(bool)
@suppress(FileNotFoundError)
def _text_in_file(text: str, filename: Path) -> bool: return any(text in line for line in filename.open())
def _text_in_file(text: str, filename: Path) -> bool:
return any(text in line for line in filename.open())
def in_docker() -> bool:
"""Is this current environment running in docker?
"""Is this current environment running in docker?
```python
type(in_docker())
```
"""
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
```python
type(in_docker())
```
"""
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
T, K = t.TypeVar("T"), t.TypeVar("K")
def resolve_filepath(path: str, ctx: str | None = None) -> str:
"""Resolve a file path to an absolute path, expand user and environment variables."""
try: return resolve_user_filepath(path, ctx)
except FileNotFoundError: return path
"""Resolve a file path to an absolute path, expand user and environment variables."""
try:
return resolve_user_filepath(path, ctx)
except FileNotFoundError:
return path
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def validate_is_path(maybe_path: str) -> bool:
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def generate_context(framework_name: str) -> _ModelContext:
from .import_utils import is_flax_available
from .import_utils import is_tf_available
from .import_utils import is_torch_available
from .import_utils import is_flax_available
from .import_utils import is_tf_available
from .import_utils import is_torch_available
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
if is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
if is_tf_available():
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
framework_versions["tensorflow"] = get_tf_version()
if is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
if is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
if is_tf_available():
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
framework_versions["tensorflow"] = get_tf_version()
if is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> DictStrAny:
return {
"runtime": llm.runtime,
"framework": "openllm",
"model_name": llm.config["model_name"],
"architecture": llm.config["architecture"],
"serialisation_format": llm._serialisation_format,
}
return {"runtime": llm.runtime, "framework": "openllm", "model_name": llm.config["model_name"], "architecture": llm.config["architecture"], "serialisation_format": llm._serialisation_format,}
_TOKENIZER_PREFIX = "_tokenizer_"
def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[DictStrAny, DictStrAny]:
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
for k in tuple(attrs.keys()):
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
return attrs, tokenizer_attrs
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
for k in tuple(attrs.keys()):
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
return attrs, tokenizer_attrs
@_overload
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: ...
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]:
...
@_overload
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: ...
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]:
...
@_overload
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: ...
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]:
...
@_overload
def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: ...
def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]:
...
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM] | type[openllm.AutoTFLLM] | type[openllm.AutoFlaxLLM] | type[openllm.AutoVLLM]:
if implementation == "tf":
from openllm import AutoTFLLM
return AutoTFLLM
elif implementation == "flax":
from openllm import AutoFlaxLLM
return AutoFlaxLLM
elif implementation == "pt":
from openllm import AutoLLM
return AutoLLM
elif implementation == "vllm":
from openllm import AutoVLLM
return AutoVLLM
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
if implementation == "tf":
from openllm import AutoTFLLM
return AutoTFLLM
elif implementation == "flax":
from openllm import AutoFlaxLLM
return AutoFlaxLLM
elif implementation == "pt":
from openllm import AutoLLM
return AutoLLM
elif implementation == "vllm":
from openllm import AutoVLLM
return AutoVLLM
else:
raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
# NOTE: The set marks contains a set of modules name
# that are available above and are whitelisted
@@ -337,80 +331,52 @@ _whitelist_modules = {"pkg"}
# XXX: define all classes, functions import above this line
# since _extras will be the locals() import from this file.
_extras: dict[str, t.Any] = {
k: v
for k, v in locals().items()
if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))
}
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))}
_extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"}
_import_structure: dict[str, list[str]] = {
"analytics": [],
"codegen": [],
"dantic": [],
"representation": ["ReprMixin"],
"lazy": ["LazyModule"],
"import_utils": [
"OPTIONAL_DEPENDENCIES",
"ENV_VARS_TRUE_VALUES",
"DummyMetaclass",
"EnvVarMixin",
"requires_dependencies",
"is_cpm_kernels_available",
"is_einops_available",
"is_flax_available",
"is_tf_available",
"is_vllm_available",
"is_torch_available",
"is_bitsandbytes_available",
"is_peft_available",
"is_datasets_available",
"is_transformers_supports_kbit",
"is_transformers_supports_agent",
"is_jupyter_available",
"is_jupytext_available",
"is_notebook_available",
"is_triton_available",
"is_autogptq_available",
"require_backends",
"analytics": [], "codegen": [], "dantic": [], "representation": ["ReprMixin"], "lazy": ["LazyModule"], "import_utils": [
"OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "EnvVarMixin", "requires_dependencies", "is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_vllm_available", "is_torch_available", "is_bitsandbytes_available", "is_peft_available", "is_datasets_available", "is_transformers_supports_kbit",
"is_transformers_supports_agent", "is_jupyter_available", "is_jupytext_available", "is_notebook_available", "is_triton_available", "is_autogptq_available", "require_backends",
],
}
if t.TYPE_CHECKING:
# NOTE: The following exports useful utils from bentoml
from . import LazyLoader as LazyLoader
from . import LazyType as LazyType
from . import analytics as analytics
from . import bentoml_cattr as bentoml_cattr
from . import codegen as codegen
from . import configure_logging as configure_logging
from . import dantic as dantic
from . import first_not_none as first_not_none
from . import reserve_free_port as reserve_free_port
from . import set_quiet_mode as set_quiet_mode
from . import validate_is_path as validate_is_path
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES
from .import_utils import DummyMetaclass as DummyMetaclass
from .import_utils import EnvVarMixin as EnvVarMixin
from .import_utils import is_autogptq_available as is_autogptq_available
from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
from .import_utils import is_datasets_available as is_datasets_available
from .import_utils import is_einops_available as is_einops_available
from .import_utils import is_flax_available as is_flax_available
from .import_utils import is_jupyter_available as is_jupyter_available
from .import_utils import is_jupytext_available as is_jupytext_available
from .import_utils import is_notebook_available as is_notebook_available
from .import_utils import is_peft_available as is_peft_available
from .import_utils import is_tf_available as is_tf_available
from .import_utils import is_torch_available as is_torch_available
from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent
from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit
from .import_utils import is_triton_available as is_triton_available
from .import_utils import is_vllm_available as is_vllm_available
from .import_utils import require_backends as require_backends
from .import_utils import requires_dependencies as requires_dependencies
from .representation import ReprMixin as ReprMixin
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras)
# NOTE: The following exports useful utils from bentoml
from . import LazyLoader as LazyLoader
from . import LazyType as LazyType
from . import analytics as analytics
from . import bentoml_cattr as bentoml_cattr
from . import codegen as codegen
from . import configure_logging as configure_logging
from . import dantic as dantic
from . import first_not_none as first_not_none
from . import reserve_free_port as reserve_free_port
from . import set_quiet_mode as set_quiet_mode
from . import validate_is_path as validate_is_path
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES
from .import_utils import DummyMetaclass as DummyMetaclass
from .import_utils import EnvVarMixin as EnvVarMixin
from .import_utils import is_autogptq_available as is_autogptq_available
from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
from .import_utils import is_datasets_available as is_datasets_available
from .import_utils import is_einops_available as is_einops_available
from .import_utils import is_flax_available as is_flax_available
from .import_utils import is_jupyter_available as is_jupyter_available
from .import_utils import is_jupytext_available as is_jupytext_available
from .import_utils import is_notebook_available as is_notebook_available
from .import_utils import is_peft_available as is_peft_available
from .import_utils import is_tf_available as is_tf_available
from .import_utils import is_torch_available as is_torch_available
from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent
from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit
from .import_utils import is_triton_available as is_triton_available
from .import_utils import is_vllm_available as is_vllm_available
from .import_utils import require_backends as require_backends
from .import_utils import requires_dependencies as requires_dependencies
from .representation import ReprMixin as ReprMixin
else:
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras)

View File

@@ -30,12 +30,11 @@ import openllm
from bentoml._internal.utils import analytics as _internal_analytics
if t.TYPE_CHECKING:
from .._types import P
from .._types import T
from .._types import P
from .._types import T
logger = logging.getLogger(__name__)
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
@@ -43,87 +42,78 @@ OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
@functools.lru_cache(maxsize=1)
def do_not_track() -> bool:
return DO_NOT_TRACK in ENV_VARS_TRUE_VALUES
return DO_NOT_TRACK in ENV_VARS_TRUE_VALUES
@functools.lru_cache(maxsize=1)
def _usage_event_debugging() -> bool:
# For BentoML developers only - debug and print event payload if turned on
return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
# For BentoML developers only - debug and print event payload if turned on
return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
try:
return func(*args, **kwargs)
except Exception as err:
if _usage_event_debugging():
if openllm.utils.get_debug_mode():
logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
else:
logger.info("Tracking Error: %s", err)
else:
logger.debug("Tracking Error: %s", err)
return wrapper
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
try:
return func(*args, **kwargs)
except Exception as err:
if _usage_event_debugging():
if openllm.utils.get_debug_mode():
logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
else:
logger.info("Tracking Error: %s", err)
else:
logger.debug("Tracking Error: %s", err)
return wrapper
@silent
def track(event_properties: attr.AttrsInstance) -> None:
if do_not_track():
return
_internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
if do_not_track():
return
_internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
@contextlib.contextmanager
def set_bentoml_tracking() -> t.Generator[None, None, None]:
original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
try:
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
yield
finally:
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
try:
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
yield
finally:
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
class EventMeta:
@property
def event_name(self) -> str:
# camel case to snake case
event_name = re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
# remove "_event" suffix
suffix_to_remove = "_event"
if event_name.endswith(suffix_to_remove):
event_name = event_name[: -len(suffix_to_remove)]
return event_name
@property
def event_name(self) -> str:
# camel case to snake case
event_name = re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
# remove "_event" suffix
suffix_to_remove = "_event"
if event_name.endswith(suffix_to_remove):
event_name = event_name[:-len(suffix_to_remove)]
return event_name
@attr.define
class OpenllmCliEvent(EventMeta):
cmd_group: str
cmd_name: str
openllm_version: str = importlib.metadata.version("openllm")
# NOTE: reserved for the do_not_track logics
duration_in_ms: t.Any = attr.field(default=None)
error_type: str = attr.field(default=None)
return_code: int = attr.field(default=None)
cmd_group: str
cmd_name: str
openllm_version: str = importlib.metadata.version("openllm")
# NOTE: reserved for the do_not_track logics
duration_in_ms: t.Any = attr.field(default=None)
error_type: str = attr.field(default=None)
return_code: int = attr.field(default=None)
@attr.define
class StartInitEvent(EventMeta):
model_name: str
llm_config: t.Dict[str, t.Any] = attr.field(default=None)
@staticmethod
def handler(llm_config: openllm.LLMConfig) -> StartInitEvent:
return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
model_name: str
llm_config: t.Dict[str, t.Any] = attr.field(default=None)
@staticmethod
def handler(llm_config: openllm.LLMConfig) -> StartInitEvent:
return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
def track_start_init(llm_config: openllm.LLMConfig) -> None:
if do_not_track():
return
track(StartInitEvent.handler(llm_config))
if do_not_track():
return
track(StartInitEvent.handler(llm_config))

View File

@@ -26,19 +26,19 @@ from pathlib import Path
import orjson
if t.TYPE_CHECKING:
from fs.base import FS
from fs.base import FS
import openllm
import openllm
from .._types import AnyCallable
from .._types import DictStrAny
from .._types import ListStr
from .._types import AnyCallable
from .._types import DictStrAny
from .._types import ListStr
PartialAny = functools.partial[t.Any]
PartialAny = functools.partial[t.Any]
else:
DictStrAny = dict
ListStr = list
PartialAny = functools.partial
DictStrAny = dict
ListStr = list
PartialAny = functools.partial
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
@@ -47,274 +47,206 @@ logger = logging.getLogger(__name__)
OPENLLM_MODEL_NAME = "# openllm: model name"
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
class ModelNameFormatter(string.Formatter):
model_keyword: t.LiteralString = "__model_name__"
model_keyword: t.LiteralString = "__model_name__"
def __init__(self, model_name: str):
"""The formatter that extends model_name to be formatted the 'service.py'."""
super().__init__()
self.model_name = model_name
def __init__(self, model_name: str):
"""The formatter that extends model_name to be formatted the 'service.py'."""
super().__init__()
self.model_name = model_name
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any:
return super().vformat(format_string, (), {self.model_keyword: self.model_name})
def can_format(self, value: str) -> bool:
try:
self.parse(value)
return True
except ValueError:
return False
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any:
return super().vformat(format_string, (), {self.model_keyword: self.model_name})
def can_format(self, value: str) -> bool:
try:
self.parse(value)
return True
except ValueError:
return False
class ModelIdFormatter(ModelNameFormatter):
model_keyword: t.LiteralString = "__model_id__"
model_keyword: t.LiteralString = "__model_id__"
class ModelAdapterMapFormatter(ModelNameFormatter):
model_keyword: t.LiteralString = "__model_adapter_map__"
model_keyword: t.LiteralString = "__model_adapter_map__"
_service_file = Path(__file__).parent.parent / "_service.py"
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
from . import DEBUG
from . import DEBUG
model_name = llm.config["model_name"]
model_name = llm.config["model_name"]
logger.debug("Generating service for %s", model_name)
logger.debug("Generating service for %s", model_name)
with open(_service_file.__fspath__(), "r") as f:
src_contents = f.readlines()
with open(_service_file.__fspath__(), "r") as f:
src_contents = f.readlines()
# modify with model name
for it in src_contents:
if OPENLLM_MODEL_NAME in it:
src_contents[src_contents.index(it)] = (
ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + "\n"
)
elif OPENLLM_MODEL_ADAPTER_MAP in it:
src_contents[src_contents.index(it)] = (
ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[
: -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)
]
+ "\n"
)
# modify with model name
for it in src_contents:
if OPENLLM_MODEL_NAME in it:
src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n")
elif OPENLLM_MODEL_ADAPTER_MAP in it:
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n")
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
if DEBUG:
logger.info("Generated script:\n%s", script)
llm_fs.writetext(llm.config["service_name"], script)
if DEBUG:
logger.info("Generated script:\n%s", script)
llm_fs.writetext(llm.config["service_name"], script)
# NOTE: The following ins extracted from attrs internal APIs
# sentinel object for unequivocal object() getattr
_sentinel = object()
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
attr = getattr(cls, attrib_name, _sentinel)
if attr is _sentinel:
return False
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
attr = getattr(cls, attrib_name, _sentinel)
if attr is _sentinel:
return False
for base_cls in cls.__mro__[1:]:
a = getattr(base_cls, attrib_name, None)
if attr is a:
return False
return True
for base_cls in cls.__mro__[1:]:
a = getattr(base_cls, attrib_name, None)
if attr is a:
return False
return True
def get_annotations(cls: type[t.Any]) -> DictStrAny:
"""Get annotations for *cls*."""
if has_own_attribute(cls, "__annotations__"):
return cls.__annotations__
"""Get annotations for *cls*."""
if has_own_attribute(cls, "__annotations__"):
return cls.__annotations__
return DictStrAny()
_classvar_prefixes = (
"typing.ClassVar",
"t.ClassVar",
"ClassVar",
"typing_extensions.ClassVar",
)
return DictStrAny()
_classvar_prefixes = ("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar",)
def is_class_var(annot: str | t.Any) -> bool:
"""Check whether *annot* is a typing.ClassVar.
"""Check whether *annot* is a typing.ClassVar.
The string comparison hack is used to avoid evaluating all string
annotations which would put attrs-based classes at a performance
disadvantage compared to plain old classes.
"""
annot = str(annot)
The string comparison hack is used to avoid evaluating all string
annotations which would put attrs-based classes at a performance
disadvantage compared to plain old classes.
"""
annot = str(annot)
# Annotation can be quoted.
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
annot = annot[1:-1]
return annot.startswith(_classvar_prefixes)
# Annotation can be quoted.
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
annot = annot[1:-1]
return annot.startswith(_classvar_prefixes)
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
"""Add __module__ and __qualname__ to a *method* if possible."""
try: method_or_cls.__module__ = cls.__module__
except AttributeError: pass
try: method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
except AttributeError: pass
try: method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
except AttributeError: pass
return method_or_cls
"""Add __module__ and __qualname__ to a *method* if possible."""
try:
method_or_cls.__module__ = cls.__module__
except AttributeError:
pass
try:
method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
except AttributeError:
pass
try:
method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
except AttributeError:
pass
return method_or_cls
# Exec the script with the given global (globs) and local (locs) variables.
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None:
eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
# ported from attrs
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
"""Create the method with the script given and return the method object."""
locs: DictStrAny = {}
# In order of debuggers like PDB being able to step through the code,
# we add a fake linecache entry.
count = 1
base_filename = filename
while True:
linecache_tuple = (len(script), None, script.splitlines(True), filename)
old_val = linecache.cache.setdefault(filename, linecache_tuple)
if old_val == linecache_tuple: break
else:
filename = f"{base_filename[:-1]}-{count}>"
count += 1
_compile_and_eval(script, globs, locs, filename)
return locs[name]
"""Create the method with the script given and return the method object."""
locs: DictStrAny = {}
# In order of debuggers like PDB being able to step through the code,
# we add a fake linecache entry.
count = 1
base_filename = filename
while True:
linecache_tuple = (len(script), None, script.splitlines(True), filename)
old_val = linecache.cache.setdefault(filename, linecache_tuple)
if old_val == linecache_tuple: break
else:
filename = f"{base_filename[:-1]}-{count}>"
count += 1
_compile_and_eval(script, globs, locs, filename)
return locs[name]
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
"""Create a tuple subclass to hold class attributes.
"""Create a tuple subclass to hold class attributes.
The subclass is a bare tuple with properties for names.
The subclass is a bare tuple with properties for names.
class MyClassAttributes(tuple):
__slots__ = ()
x = property(itemgetter(0))
"""
from . import SHOW_CODEGEN
class MyClassAttributes(tuple):
__slots__ = ()
x = property(itemgetter(0))
"""
from . import SHOW_CODEGEN
attr_class_name = f"{cls_name}Attributes"
attr_class_template = [
f"class {attr_class_name}(tuple):",
" __slots__ = ()",
]
if attr_names:
for i, attr_name in enumerate(attr_names): attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
else: attr_class_template.append(" pass")
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
_compile_and_eval("\n".join(attr_class_template), globs)
return globs[attr_class_name]
attr_class_name = f"{cls_name}Attributes"
attr_class_template = [f"class {attr_class_name}(tuple):", " __slots__ = ()",]
if attr_names:
for i, attr_name in enumerate(attr_names):
attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
else:
attr_class_template.append(" pass")
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
_compile_and_eval("\n".join(attr_class_template), globs)
return globs[attr_class_name]
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None,) -> AnyCallable:
from . import SHOW_CODEGEN
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
if annotations: meth.__annotations__ = annotations
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
def generate_function(
typ: type[t.Any],
func_name: str,
lines: list[str] | None,
args: tuple[str, ...] | None,
globs: dict[str, t.Any],
annotations: dict[str, t.Any] | None = None,
) -> AnyCallable:
from . import SHOW_CODEGEN
return meth
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
if annotations: meth.__annotations__ = annotations
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: t.LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
from . import dantic
from . import field_env_key
return meth
def identity(_: str, x_value: t.Any) -> t.Any:
return x_value
default_callback = identity if default_callback is None else default_callback
def make_env_transformer(
cls: type[openllm.LLMConfig],
model_name: str,
suffix: t.LiteralString | None = None,
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
globs: DictStrAny | None = None,
) -> AnyCallable:
from . import dantic
from . import field_env_key
globs = {} if globs is None else globs
globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
def identity(_: str, x_value: t.Any) -> t.Any:
return x_value
lines: ListStr = [
"__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },",
" )", " for f in fields", "]",
]
fields_ann = "list[attr.Attribute[t.Any]]"
default_callback = identity if default_callback is None else default_callback
globs = {} if globs is None else globs
globs.update(
{
"__populate_env": dantic.env_converter,
"__default_callback": default_callback,
"__field_env": field_env_key,
"__suffix": suffix or "",
"__model_name": model_name,
}
)
lines: ListStr = [
"__env = lambda field_name: __field_env(__model_name, field_name, __suffix)",
"return [",
" f.evolve(",
" default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),",
" metadata={",
" 'env': f.metadata.get('env', __env(f.name)),",
" 'description': f.metadata.get('description', '(not provided)'),",
" },",
" )",
" for f in fields",
"]",
]
fields_ann = "list[attr.Attribute[t.Any]]"
return generate_function(
cls,
"__auto_env",
lines,
args=("_", "fields"),
globs=globs,
annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},
)
return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},)
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
"""Enhance sdk with nice repr that plays well with your brain."""
from .representation import ReprMixin
"""Enhance sdk with nice repr that plays well with your brain."""
from .representation import ReprMixin
if name is None: name = func.__name__.strip("_")
_signatures = inspect.signature(func).parameters
def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
else: doc = func.__doc__
return t.cast(_T, functools.update_wrapper(
types.new_class(
name,
(PartialAny, ReprMixin),
exec_body=lambda ns: ns.update(
{
"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]),
"__repr_args__": _repr_args,
"__repr__": _repr,
"__doc__": inspect.cleandoc(doc),
"__module__": "openllm",
}
),
)(func, **attrs),
func,
))
if name is None: name = func.__name__.strip("_")
_signatures = inspect.signature(func).parameters
def _repr(self: ReprMixin) -> str:
return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
else: doc = func.__doc__
return t.cast(_T, functools.update_wrapper(types.new_class(name, (PartialAny, ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))

View File

@@ -31,478 +31,432 @@ from click import shell_completion as sc
from click import types as click_types
if t.TYPE_CHECKING:
from attr import _ValidatorType
from attr import _ValidatorType
from .._types import ListAny
from .._types import ListAny
_T = t.TypeVar("_T")
AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: type[t.Any] | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
# TODO: support parsing nested attrs class and Union
envvar = field.metadata["env"]
dasherized = inflection.dasherize(name)
underscored = inflection.underscore(name)
def attrs_to_options(
name: str,
field: attr.Attribute[t.Any],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
suffix_sampling: bool = False,
) -> t.Callable[[FC], FC]:
# TODO: support parsing nested attrs class and Union
envvar = field.metadata["env"]
dasherized = inflection.dasherize(name)
underscored = inflection.underscore(name)
if typ in (None, attr.NOTHING):
typ = field.type
if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
if typ in (None, attr.NOTHING):
typ = field.type
if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
full_option_name = f"--{dasherized}"
if field.type is bool: full_option_name += f"/--no-{dasherized}"
if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
else: identifier = f"{model_name}_{underscored}"
return cog.optgroup.option(
identifier,
full_option_name,
type=parse_type(typ),
required=field.default is attr.NOTHING,
default=field.default if field.default not in (attr.NOTHING, None) else None,
show_default=True,
multiple=allows_multiple(typ) if typ else False,
help=field.metadata.get("description", "(No description provided)"),
show_envvar=True,
envvar=envvar,
)
full_option_name = f"--{dasherized}"
if field.type is bool: full_option_name += f"/--no-{dasherized}"
if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
else: identifier = f"{model_name}_{underscored}"
return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,)
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
if env is not None:
value = os.environ.get(env, value)
if value is not None and isinstance(value, str):
try:
return orjson.loads(value.lower())
except orjson.JSONDecodeError as err:
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
return value
if env is not None:
value = os.environ.get(env, value)
if value is not None and isinstance(value, str):
try:
return orjson.loads(value.lower())
except orjson.JSONDecodeError as err:
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
return value
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any:
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
def Field(
default: t.Any = None,
*,
ge: int | float | None = None,
le: int | float | None = None,
validator: _ValidatorType[_T] | None = None,
description: str | None = None,
env: str | None = None,
auto_default: bool = False,
use_default_converter: bool = True,
**attrs: t.Any,
) -> t.Any:
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
By default, if both validator and ge are provided, then then ge will be
piped into first, then all of the other validator will be run afterwards.
By default, if both validator and ge are provided, then then ge will be
piped into first, then all of the other validator will be run afterwards.
Args:
default: The default value for ``dantic.Field``. Defaults to ``None``.
ge: Greater than or equal to. Defaults to None.
le: Less than or equal to. Defaults to None.
validator: Optional attrs-compatible validators type. Default to None
description: the documentation for the field. Defaults to None.
env: the environment variable to read from. Defaults to None.
auto_default: a bool indicating whether to use the default value as the environment.
Defaults to False. If set to True, the behaviour of this Field will also depends
on kw_only. If kw_only=True, the this field will become 'Required' and the default
value is omitted. If kw_only=False, then the default value will be used as before.
use_default_converter: a bool indicating whether to use the default converter. Defaults
to True. If set to False, then the default converter will not be used.
The default converter converts a given value from the environment variable
for this given Field.
**attrs: The rest of the arguments are passed to attr.field
"""
metadata = attrs.pop("metadata", {})
if description is None:
description = "(No description provided)"
metadata["description"] = description
if env is not None:
metadata["env"] = env
piped: list[_ValidatorType[t.Any]] = []
Args:
default: The default value for ``dantic.Field``. Defaults to ``None``.
ge: Greater than or equal to. Defaults to None.
le: Less than or equal to. Defaults to None.
validator: Optional attrs-compatible validators type. Default to None
description: the documentation for the field. Defaults to None.
env: the environment variable to read from. Defaults to None.
auto_default: a bool indicating whether to use the default value as the environment.
Defaults to False. If set to True, the behaviour of this Field will also depends
on kw_only. If kw_only=True, the this field will become 'Required' and the default
value is omitted. If kw_only=False, then the default value will be used as before.
use_default_converter: a bool indicating whether to use the default converter. Defaults
to True. If set to False, then the default converter will not be used.
The default converter converts a given value from the environment variable
for this given Field.
**attrs: The rest of the arguments are passed to attr.field
"""
metadata = attrs.pop("metadata", {})
if description is None:
description = "(No description provided)"
metadata["description"] = description
if env is not None:
metadata["env"] = env
piped: list[_ValidatorType[t.Any]] = []
converter = attrs.pop("converter", None)
if use_default_converter:
converter = functools.partial(env_converter, env=env)
converter = attrs.pop("converter", None)
if use_default_converter:
converter = functools.partial(env_converter, env=env)
if ge is not None:
piped.append(attr.validators.ge(ge))
if le is not None:
piped.append(attr.validators.le(le))
if validator is not None:
piped.append(validator)
if ge is not None:
piped.append(attr.validators.ge(ge))
if le is not None:
piped.append(attr.validators.le(le))
if validator is not None:
piped.append(validator)
if len(piped) == 0:
_validator = None
elif len(piped) == 1:
_validator = piped[0]
else:
_validator = attr.validators.and_(*piped)
if len(piped) == 0:
_validator = None
elif len(piped) == 1:
_validator = piped[0]
else:
_validator = attr.validators.and_(*piped)
factory = attrs.pop("factory", None)
if factory is not None and default is not None:
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
# NOTE: the behaviour of this is we will respect factory over the default
if factory is not None:
attrs["factory"] = factory
else:
attrs["default"] = default
factory = attrs.pop("factory", None)
if factory is not None and default is not None:
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
# NOTE: the behaviour of this is we will respect factory over the default
if factory is not None:
attrs["factory"] = factory
else:
attrs["default"] = default
kw_only = attrs.pop("kw_only", False)
if auto_default and kw_only:
attrs.pop("default")
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
kw_only = attrs.pop("kw_only", False)
if auto_default and kw_only:
attrs.pop("default")
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]:
"""Transforms the pydantic field's type into a click-compatible type.
"""Transforms the pydantic field's type into a click-compatible type.
Args:
field_type: pydantic field type
Args:
field_type: pydantic field type
Returns:
ParamType: click type equivalent
"""
from . import lenient_issubclass
if t.get_origin(field_type) is t.Union:
raise NotImplementedError("Unions are not supported")
# enumeration strings or other Enum derivatives
if lenient_issubclass(field_type, Enum):
return EnumChoice(enum=field_type, case_sensitive=True)
# literals are enum-like with way less functionality
if is_literal(field_type):
return LiteralChoice(value=field_type, case_sensitive=True)
# modules, classes, functions
if is_typing(field_type):
return ModuleType()
# entire dictionaries:
# using a Dict, convert in advance
if is_mapping(field_type):
return JsonType()
# list, List[p], Tuple[p], Set[p] and so on
if is_container(field_type):
return parse_container_args(field_type)
# bytes are not natively supported by click
if lenient_issubclass(field_type, bytes):
return BytesType()
# return the current type: it should be a primitive
return field_type
Returns:
ParamType: click type equivalent
"""
from . import lenient_issubclass
if t.get_origin(field_type) is t.Union:
raise NotImplementedError("Unions are not supported")
# enumeration strings or other Enum derivatives
if lenient_issubclass(field_type, Enum):
return EnumChoice(enum=field_type, case_sensitive=True)
# literals are enum-like with way less functionality
if is_literal(field_type):
return LiteralChoice(value=field_type, case_sensitive=True)
# modules, classes, functions
if is_typing(field_type):
return ModuleType()
# entire dictionaries:
# using a Dict, convert in advance
if is_mapping(field_type):
return JsonType()
# list, List[p], Tuple[p], Set[p] and so on
if is_container(field_type):
return parse_container_args(field_type)
# bytes are not natively supported by click
if lenient_issubclass(field_type, bytes):
return BytesType()
# return the current type: it should be a primitive
return field_type
def is_typing(field_type: type) -> bool:
"""Checks whether the current type is a module-like type.
"""Checks whether the current type is a module-like type.
Args:
field_type: pydantic field type
Args:
field_type: pydantic field type
Returns:
bool: true if the type is itself a type
"""
raw = t.get_origin(field_type)
if raw is None:
return False
if raw is type or raw is t.Type:
return True
Returns:
bool: true if the type is itself a type
"""
raw = t.get_origin(field_type)
if raw is None:
return False
if raw is type or raw is t.Type:
return True
return False
def is_literal(field_type: type) -> bool:
"""Checks whether the given field type is a Literal type or not.
"""Checks whether the given field type is a Literal type or not.
Literals are weird: isinstance and subclass do not work, so you compare
the origin with the Literal declaration itself.
Literals are weird: isinstance and subclass do not work, so you compare
the origin with the Literal declaration itself.
Args:
field_type: current pydantic type
Returns:
bool: true if Literal type, false otherwise
"""
origin = t.get_origin(field_type)
return origin is not None and origin is t.Literal
Args:
field_type: current pydantic type
Returns:
bool: true if Literal type, false otherwise
"""
origin = t.get_origin(field_type)
return origin is not None and origin is t.Literal
class ModuleType(ParamType):
name = "module"
name = "module"
def _import_object(self, value: str) -> t.Any:
module_name, class_name = value.rsplit(".", maxsplit=1)
if not all(s.isidentifier() for s in module_name.split(".")):
raise ValueError(f"'{value}' is not a valid module name")
if not class_name.isidentifier():
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
def _import_object(self, value: str) -> t.Any:
module_name, class_name = value.rsplit(".", maxsplit=1)
if not all(s.isidentifier() for s in module_name.split(".")):
raise ValueError(f"'{value}' is not a valid module name")
if not class_name.isidentifier():
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
module = importlib.import_module(module_name)
if class_name:
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
try:
if isinstance(value, str):
return self._import_object(value)
return value
except Exception as exc:
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
module = importlib.import_module(module_name)
if class_name:
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
try:
if isinstance(value, str):
return self._import_object(value)
return value
except Exception as exc:
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
class EnumChoice(click.Choice):
name = "enum"
name = "enum"
def __init__(self, enum: Enum, case_sensitive: bool = False):
"""Enum type support for click that extends ``click.Choice``.
def __init__(self, enum: Enum, case_sensitive: bool = False):
"""Enum type support for click that extends ``click.Choice``.
Args:
enum: Given enum
case_sensitive: Whether this choice should be case case_sensitive.
"""
self.mapping = enum
self.internal_type = type(enum)
choices: ListAny = [e.name for e in enum.__class__]
super().__init__(choices, case_sensitive)
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
if isinstance(value, self.internal_type):
return value
result = super().convert(value, param, ctx)
if isinstance(result, str):
result = self.internal_type[result]
return result
Args:
enum: Given enum
case_sensitive: Whether this choice should be case case_sensitive.
"""
self.mapping = enum
self.internal_type = type(enum)
choices: ListAny = [e.name for e in enum.__class__]
super().__init__(choices, case_sensitive)
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
if isinstance(value, self.internal_type):
return value
result = super().convert(value, param, ctx)
if isinstance(result, str):
result = self.internal_type[result]
return result
class LiteralChoice(EnumChoice):
name = "literal"
def __init__(self, value: t.Any, case_sensitive: bool = False):
"""Literal support for click."""
# expect every literal value to belong to the same primitive type
values = list(value.__args__)
item_type = type(values[0])
if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
_mapping = {str(v): v for v in values}
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
self.internal_type = item_type
name = "literal"
def __init__(self, value: t.Any, case_sensitive: bool = False):
"""Literal support for click."""
# expect every literal value to belong to the same primitive type
values = list(value.__args__)
item_type = type(values[0])
if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
_mapping = {str(v): v for v in values}
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
self.internal_type = item_type
def allows_multiple(field_type: type[t.Any]) -> bool:
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
For containers, it exploits click's support for lists and such to use the same option multiple times
to create a complex object: `python run.py --subsets train --subsets test`
# becomes `subsets: ["train", "test"]`.
For containers, it exploits click's support for lists and such to use the same option multiple times
to create a complex object: `python run.py --subsets train --subsets test`
# becomes `subsets: ["train", "test"]`.
Args:
field_type: pydantic type.
Args:
field_type: pydantic type.
Returns:
bool: true if it's a composite field (lists, containers and so on), false otherwise
"""
# Early out for mappings, since it's better to deal with them using strings.
if is_mapping(field_type):
return False
# Activate multiple option for (simple) container types
if is_container(field_type):
args = parse_container_args(field_type)
# A non-composite type has a single argument, such as 'List[int]'
# A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
# For the moment, only non-composite types are allowed.
return not isinstance(args, tuple)
Returns:
bool: true if it's a composite field (lists, containers and so on), false otherwise
"""
# Early out for mappings, since it's better to deal with them using strings.
if is_mapping(field_type):
return False
# Activate multiple option for (simple) container types
if is_container(field_type):
args = parse_container_args(field_type)
# A non-composite type has a single argument, such as 'List[int]'
# A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
# For the moment, only non-composite types are allowed.
return not isinstance(args, tuple)
return False
def is_mapping(field_type: type) -> bool:
"""Checks whether this field represents a dictionary or JSON object.
"""Checks whether this field represents a dictionary or JSON object.
Args:
field_type (type): pydantic type
Returns:
bool: true when the field is a dict-like object, false otherwise.
"""
# Early out for standard containers.
from . import lenient_issubclass
if lenient_issubclass(field_type, t.Mapping): return True
# for everything else or when the typing is more complex, check its origin
origin = t.get_origin(field_type)
if origin is None: return False
return lenient_issubclass(origin, t.Mapping)
Args:
field_type (type): pydantic type
Returns:
bool: true when the field is a dict-like object, false otherwise.
"""
# Early out for standard containers.
from . import lenient_issubclass
if lenient_issubclass(field_type, t.Mapping): return True
# for everything else or when the typing is more complex, check its origin
origin = t.get_origin(field_type)
if origin is None: return False
return lenient_issubclass(origin, t.Mapping)
def is_container(field_type: type) -> bool:
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
Args:
field_type: pydantic field type
Returns:
bool: true if a container, false otherwise
"""
# do not consider strings or byte arrays as containers
if field_type in (str, bytes): return False
# Early out for standard containers: list, tuple, range
from . import lenient_issubclass
if lenient_issubclass(field_type, t.Container): return True
origin = t.get_origin(field_type)
# Early out for non-typing objects
if origin is None: return False
return lenient_issubclass(origin, t.Container)
Args:
field_type: pydantic field type
Returns:
bool: true if a container, false otherwise
"""
# do not consider strings or byte arrays as containers
if field_type in (str, bytes): return False
# Early out for standard containers: list, tuple, range
from . import lenient_issubclass
if lenient_issubclass(field_type, t.Container): return True
origin = t.get_origin(field_type)
# Early out for non-typing objects
if origin is None: return False
return lenient_issubclass(origin, t.Container)
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType]:
"""Parses the arguments inside a container type (lists, tuples and so on).
"""Parses the arguments inside a container type (lists, tuples and so on).
Args:
field_type: pydantic field type
Returns:
ParamType | tuple[ParamType]: single click-compatible type or a tuple
"""
if not is_container(field_type):
raise ValueError("Field type is not a container type.")
args = t.get_args(field_type)
# Early out for untyped containers: standard lists, tuples, List[Any]
# Use strings when the type is unknown, avoid click's type guessing
if len(args) == 0:
return click_types.convert_type(str)
# Early out for homogenous containers: Tuple[int], List[str]
if len(args) == 1:
return parse_single_arg(args[0])
# Early out for homogenous tuples of indefinite length: Tuple[int, ...]
if len(args) == 2 and args[1] is Ellipsis:
return parse_single_arg(args[0])
# Then deal with fixed-length containers: Tuple[str, int, int]
return tuple(parse_single_arg(arg) for arg in args)
Args:
field_type: pydantic field type
Returns:
ParamType | tuple[ParamType]: single click-compatible type or a tuple
"""
if not is_container(field_type):
raise ValueError("Field type is not a container type.")
args = t.get_args(field_type)
# Early out for untyped containers: standard lists, tuples, List[Any]
# Use strings when the type is unknown, avoid click's type guessing
if len(args) == 0:
return click_types.convert_type(str)
# Early out for homogenous containers: Tuple[int], List[str]
if len(args) == 1:
return parse_single_arg(args[0])
# Early out for homogenous tuples of indefinite length: Tuple[int, ...]
if len(args) == 2 and args[1] is Ellipsis:
return parse_single_arg(args[0])
# Then deal with fixed-length containers: Tuple[str, int, int]
return tuple(parse_single_arg(arg) for arg in args)
def parse_single_arg(arg: type) -> ParamType:
"""Returns the click-compatible type for container origin types.
"""Returns the click-compatible type for container origin types.
In this case, returns string when it's not inferrable, a JSON for mappings
and the original type itself in every other case (ints, floats and so on).
Bytes is a special case, not natively handled by click.
In this case, returns string when it's not inferrable, a JSON for mappings
and the original type itself in every other case (ints, floats and so on).
Bytes is a special case, not natively handled by click.
Args:
arg (type): single argument
Returns:
ParamType: click-compatible type
"""
from . import lenient_issubclass
# When we don't know the type, we choose 'str'
if arg is t.Any: return click_types.convert_type(str)
# For containers and nested models, we use JSON
if is_container(arg): return JsonType()
if lenient_issubclass(arg, bytes): return BytesType()
return click_types.convert_type(arg)
Args:
arg (type): single argument
Returns:
ParamType: click-compatible type
"""
from . import lenient_issubclass
# When we don't know the type, we choose 'str'
if arg is t.Any: return click_types.convert_type(str)
# For containers and nested models, we use JSON
if is_container(arg): return JsonType()
if lenient_issubclass(arg, bytes): return BytesType()
return click_types.convert_type(arg)
class BytesType(ParamType):
name = "bytes"
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
if isinstance(value, bytes):
return value
try:
return str.encode(value)
except Exception as exc:
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
name = "bytes"
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
if isinstance(value, bytes):
return value
try:
return str.encode(value)
except Exception as exc:
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
CYGWIN = sys.platform.startswith("cygwin")
WIN = sys.platform.startswith("win")
if sys.platform.startswith("win") and WIN:
def _get_argv_encoding() -> str:
import locale
def _get_argv_encoding() -> str:
import locale
return locale.getpreferredencoding()
return locale.getpreferredencoding()
else:
def _get_argv_encoding() -> str:
return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
def _get_argv_encoding() -> str:
return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
class CudaValueType(ParamType):
name = "cuda"
envvar_list_splitter = ","
is_composite = True
typ = click_types.convert_type(str)
name = "cuda"
envvar_list_splitter = ","
is_composite = True
typ = click_types.convert_type(str)
def split_envvar_value(self, rv: str) -> t.Sequence[str]:
var = tuple(i for i in rv.split(self.envvar_list_splitter))
if "-1" in var:
return var[: var.index("-1")]
return var
def split_envvar_value(self, rv: str) -> t.Sequence[str]:
var = tuple(i for i in rv.split(self.envvar_list_splitter))
if "-1" in var:
return var[:var.index("-1")]
return var
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
Args:
ctx: Invocation context for this command.
param: The parameter that is requesting completion.
incomplete: Value being completed. May be empty.
"""
from ..utils import available_devices
Args:
ctx: Invocation context for this command.
param: The parameter that is requesting completion.
incomplete: Value being completed. May be empty.
"""
from ..utils import available_devices
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
if isinstance(value, bytes):
enc = _get_argv_encoding()
try:
value = value.decode(enc)
except UnicodeError:
fs_enc = sys.getfilesystemencoding()
if fs_enc != enc:
try:
value = value.decode(fs_enc)
except UnicodeError:
value = value.decode("utf-8", "replace")
else:
value = value.decode("utf-8", "replace")
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
if isinstance(value, bytes):
enc = _get_argv_encoding()
try:
value = value.decode(enc)
except UnicodeError:
fs_enc = sys.getfilesystemencoding()
if fs_enc != enc:
try:
value = value.decode(fs_enc)
except UnicodeError:
value = value.decode("utf-8", "replace")
else:
value = value.decode("utf-8", "replace")
return tuple(self.typ(x, param, ctx) for x in value.split(","))
def __repr__(self) -> str:
"""CUDA is a click.STRING extension."""
return "STRING"
return tuple(self.typ(x, param, ctx) for x in value.split(","))
def __repr__(self) -> str:
"""CUDA is a click.STRING extension."""
return "STRING"
CUDA = CudaValueType()
class JsonType(ParamType):
name = "json"
name = "json"
def __init__(self, should_load: bool = True) -> None:
"""Support JSON type for click.ParamType.
def __init__(self, should_load: bool = True) -> None:
"""Support JSON type for click.ParamType.
Args:
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
"""
super().__init__()
self.should_load = should_load
Args:
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
"""
super().__init__()
self.should_load = should_load
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
from . import LazyType
if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value
try: return orjson.loads(value)
except orjson.JSONDecodeError as exc: self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
from . import LazyType
if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value
try:
return orjson.loads(value)
except orjson.JSONDecodeError as exc:
self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)

View File

@@ -19,27 +19,24 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
if t.TYPE_CHECKING:
from ..models.auto.factory import _LazyAutoMapping
from ..models.auto.factory import _LazyAutoMapping
class FlaxFlanT5(metaclass=DummyMetaclass):
_backends = ["flax"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["flax"])
_backends = ["flax"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["flax"])
class FlaxOPT(metaclass=DummyMetaclass):
_backends = ["flax"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["flax"])
_backends = ["flax"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["flax"])
class AutoFlaxLLM(metaclass=DummyMetaclass):
_backends = ["flax"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["flax"])
_backends = ["flax"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["flax"])
MODEL_FLAX_MAPPING = t.cast("_LazyAutoMapping", None)

View File

@@ -19,14 +19,13 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
class ChatGLM(metaclass=DummyMetaclass):
_backends = ["torch", "cpm_kernels"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch", "cpm_kernels"])
_backends = ["torch", "cpm_kernels"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch", "cpm_kernels"])
class Baichuan(metaclass=DummyMetaclass):
_backends = ["torch", "cpm_kernels"]
_backends = ["torch", "cpm_kernels"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch", "cpm_kernels"])
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch", "cpm_kernels"])

View File

@@ -19,7 +19,7 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
class Falcon(metaclass=DummyMetaclass):
_backends = ["torch", "einops"]
_backends = ["torch", "einops"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch", "einops"])
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch", "einops"])

View File

@@ -19,7 +19,7 @@ from ..utils import DummyMetaclass
from ..utils import require_backends
class MPT(metaclass=DummyMetaclass):
_backends = ["torch", "triton"]
_backends = ["torch", "triton"]
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch"])
def __init__(self, *args: t.Any, **attrs: t.Any):
require_backends(self, ["torch"])

Some files were not shown because too many files have changed in this diff Show More