mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-11 18:09:52 -04:00
style: define experimental guidelines (#168)
This commit is contained in:
@@ -32,6 +32,12 @@ repos:
|
||||
types: [python]
|
||||
exclude: ^(docs|tools|tests)
|
||||
args: [--config=pyproject.toml]
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.40.1
|
||||
hooks:
|
||||
- id: yapf
|
||||
types: [python]
|
||||
args: [--parallel, --recursive]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy
|
||||
|
||||
@@ -155,8 +155,13 @@ hatch run tests:snapshot-models
|
||||
|
||||
## Working with Git
|
||||
|
||||
To filter out most of the generated commits for infrastructure, use ``--invert-grep`` in conjunction with ``--grep``
|
||||
to filter out all commits with regex `"[generated]"`
|
||||
To filter out most of the generated commits for infrastructure, use
|
||||
`--invert-grep` in conjunction with `--grep` to filter out all commits with
|
||||
regex `"[generated]"`
|
||||
|
||||
## Style
|
||||
|
||||
See [STYLE.md](STYLE.md) for our style guide.
|
||||
|
||||
## Releasing a New Version
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
<img src="https://img.shields.io/pypi/pyversions/openllm.svg?logo=python&label=Python&logoColor=gold" alt="python_version" />
|
||||
</a><a href="https://github.com/pypa/hatch">
|
||||
<img src="https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg" alt="Hatch" />
|
||||
</a><a href="https://github.com/bentoml/OpenLLM/blob/main/STYLE.md">
|
||||
<img src="https://img.shields.io/badge/code%20style-experimental-000000.svg" alt="code style" />
|
||||
</a><a href="https://github.com/astral-sh/ruff">
|
||||
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json" alt="Ruff" />
|
||||
</a><a href="https://github.com/python/mypy">
|
||||
|
||||
160
STYLE.md
Normal file
160
STYLE.md
Normal file
@@ -0,0 +1,160 @@
|
||||
## the coding style
|
||||
|
||||
This documentation serves as a brief discussion of the coding style used for
|
||||
OpenLLM. As you have noticed, it is different from the conventional
|
||||
[PEP8](https://peps.python.org/pep-0008/) style as across many Python projects.
|
||||
The manifestation of OpenLLM code style is a combination of
|
||||
[Google Python Style](https://google.github.io/styleguide/pyguide.html),
|
||||
inspiration from coding language such as APL, Haskell, and is designed for fast,
|
||||
experimental development and prototyping.
|
||||
|
||||
Everyone always has their own opinions on style. I believe this is exemplified
|
||||
further within the Python community, as it tries to be beginner-friendly, and
|
||||
therefore most people hold a very strong opinion on styling. I don't have a
|
||||
strong opinion on style either (I don't have any issue with PEP8, as we use it
|
||||
for our other projects), as long as:
|
||||
|
||||
- You don't use any linter, formatter that change the style drastically other
|
||||
than what specified within the projects' [`pyproject.toml`](./pyproject.toml).
|
||||
- The code you contribute is not widely different from the style of the code
|
||||
surrounding it.
|
||||
|
||||
With that being said, I want to use this project as a playground the explore a
|
||||
style that is both: "feels natural" and expressive for mathematical reasoning. I
|
||||
hope that you find this guide somewhat thought-provoking and interesting, that
|
||||
you can iterate and try to adopt some of them as part of the process
|
||||
contributing to the library.
|
||||
|
||||
While PEP8 is a great base for a style guide, I find it to be having way too
|
||||
much white spaces and makes the code feels 'robotic'. Having a deterministic
|
||||
style and formatter is great to reduce the overhead of stylistic discussions,
|
||||
but I think it is important to write code that express the intent of reasoning.
|
||||
(_The policy here is definitely not "shovel everything into one line", but
|
||||
rather "compact and flowing"_)
|
||||
|
||||
The styling is heavily inspired by
|
||||
[Kenneth Iverson's](https://en.wikipedia.org/wiki/Kenneth_E._Iverson) 1979
|
||||
Turing award lecture,
|
||||
[Notation as a Tool of Thought](https://www.eecg.toronto.edu/~jzhu/csc326/readings/iverson.pdf),
|
||||
and a lot of the stylistic inspiration comes from
|
||||
[Jeremy Howard's](https://jeremy.fast.ai/) [fastai](https://docs.fast.ai/). One
|
||||
thing that has been stuck with me ever since is the idea of "brevity facilitates
|
||||
reasoning", as such the tersity of style aren't just for the sake of shortness,
|
||||
rather the brevity of expression. (it enables
|
||||
[expository programming](http://archive.vector.org.uk/art10000980), combining
|
||||
with prototyping new ideas and logics within models implementation)
|
||||
|
||||
## some guidelines
|
||||
|
||||
Though I have stopped using deterministic formatter and linter, I do understand
|
||||
that people have preferences for using these tools, and it plays nicely with IDE
|
||||
and editors. As such, I included a [`pyproject.toml`](./pyproject.toml) file
|
||||
that specifies some configuration for the tools that makes it compiliant with
|
||||
the repository's style. In short, some of the tools include `ruff`, `yapf`, and
|
||||
`interrogate`. Since we manage everything via `hatch`, refer back to the
|
||||
[DEVELOPMENT.md](./DEVELOPMENT.md) for more information on this.
|
||||
|
||||
Overtime, Python has incorporated a lot of features that supports this style of
|
||||
coding, including list comprehension, generator expression, lambda, array-based
|
||||
programming. Yet, Python will remain verbose per se, and the goal is that to
|
||||
make code fit nicely on a screen, and we don't have to always scroll downwards.
|
||||
|
||||
While brevity is important, it is also important to make sure functions are
|
||||
somewhat, type-safe. Since there is no real type-safety when working with
|
||||
Python, typing should be a best-effort to make sure we don't introduce too many
|
||||
bugs.
|
||||
|
||||
### naming
|
||||
|
||||
- follow Python standard for this, I don't have too much opinion on this. Just
|
||||
make sure that it is descriptive, and the abbreviation describes the intent of
|
||||
the variable. i.e: `to_gpu` instead of `t_gpu`, `to_cpu` instead of `t_cpu`.
|
||||
- any math-related notation or neural net layers should be expressive and stay
|
||||
close to the paper as much as possible. For example: `lm_head.weight` instead
|
||||
of `lm_head.w`. Espically for implementing custom kernels and layers, it is
|
||||
crucial to follow its nomenclature. E.g: `conv1` instead of
|
||||
`first_conv_layer`.
|
||||
|
||||
_If you have any suggestions, feel free to give it on our discord server!_
|
||||
|
||||
### layout
|
||||
|
||||
- Preferably not a lot of whitespaces, but rather flowing. If you can fit
|
||||
everything for `if`, `def` or a `return` within one line, then there's no need
|
||||
to break it into multiple line:
|
||||
|
||||
```python
|
||||
def foo(x): return rotate_cv(x) if x > 0 else -x
|
||||
```
|
||||
|
||||
- imports should be grouped by their types, and each import should be designated
|
||||
on its own line.
|
||||
|
||||
```python
|
||||
import os
|
||||
import sys
|
||||
```
|
||||
This is partially to make it easier to work with merge-conflicts, and easier
|
||||
for IDE to navigate context definition.
|
||||
|
||||
- indent with 2 spaces, which follow the Google codestyle.
|
||||
|
||||
- With regards to writing operator, try to follow the domain-specific notation.
|
||||
I.e: when writing pathlib, just don't add space since that is not how you
|
||||
write a path in the terminal. `yapf` will try to accommodate some of this
|
||||
changes.
|
||||
|
||||
- Avoid trailing whitespace
|
||||
|
||||
- use array, pytorch or numpy-based indexing where possible.
|
||||
|
||||
- If you need to export anything, put it in `__all__` or do lazy export for
|
||||
type-safe checker.
|
||||
|
||||
### misc
|
||||
|
||||
- import alias should be concise and descriptive. A convention is to always
|
||||
`import typing as t`.
|
||||
- Writing docstring when it is possible. No need to comment everything asn it
|
||||
makes the codebase hard to read. For docstring, follow the Google style guide.
|
||||
- We do lazy imports, so consult some of the `__init__.py` to see how we do it.
|
||||
- Documentation is still _working-in-progress_, but tldr it will be written in
|
||||
MDX and will be hosted on the GitHub Pages, so stay tuned!
|
||||
- If anything that is not used for runtime, just put it under `t.TYPE_CHECKING`
|
||||
|
||||
### note on codegen
|
||||
|
||||
- We also do some codegen for some of the assignment functions. These logics are
|
||||
largely based on the work of [attrs](https://github.com/python-attrs/attrs) to
|
||||
ensure fast and isolated codegen in Python. If you need codegen but don't know
|
||||
how it works, feel free to mention @aarnphm_ on discord!
|
||||
|
||||
## FAQ
|
||||
|
||||
### Why not use `black`?
|
||||
|
||||
`black` is used on our other projects, but I rather find `black` to be very
|
||||
verbose and overtime it is annoying to work with too much whitespaces.
|
||||
|
||||
### Why not PEP8?
|
||||
|
||||
PEP8 is great if you are writing library such as this, but I'm going to do a lot
|
||||
of experimenting for implementing papers, so I decided early on that PEP8 is
|
||||
probably not fit here, and want to explore more expressive style.
|
||||
|
||||
### Editor is complaining about the style, what should I do?
|
||||
|
||||
Kindly ask you to disable linting for this project 🤗. I will try my best to
|
||||
accomodate with ruff and yapf, but I don't want to spend too much time on this.
|
||||
It is pretty stragithforward to disable it in your editor, with google.
|
||||
|
||||
### Style might put off new contributors?
|
||||
|
||||
I don't think so, as mentioned before, I don't have too much opinion on style as
|
||||
long as it somewhat follow what I have described above or the style of the code
|
||||
surrounding it. I will still accept styles PR as long as it is not too drastic.
|
||||
Just make sure to add the revision to `.git-blame-ignore-revs` so that
|
||||
`git blame` would work correctly.
|
||||
|
||||
As for people who are too close-minded about styling, such individuals aren't
|
||||
the ones we want to work with anyway!
|
||||
3
changelog.d/168.chore.md
Normal file
3
changelog.d/168.chore.md
Normal file
@@ -0,0 +1,3 @@
|
||||
Define specific style guideline for the project. See
|
||||
[STYLE.md](https://github.com/bentoml/OpenLLM/blob/main/STYLE.md) for more
|
||||
information.
|
||||
@@ -24,13 +24,11 @@ llm_runner = openllm.Runner(model, llm_config=llm_config)
|
||||
|
||||
svc = bentoml.Service(name="llm-service", runners=[llm_runner])
|
||||
|
||||
|
||||
@svc.on_startup
|
||||
def download(_: bentoml.Context):
|
||||
llm_runner.download_model()
|
||||
|
||||
llm_runner.download_model()
|
||||
|
||||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.Text())
|
||||
async def prompt(input_text: str) -> str:
|
||||
answer = await llm_runner.generate.async_run(input_text)
|
||||
return answer[0]["generated_text"]
|
||||
answer = await llm_runner.generate.async_run(input_text)
|
||||
return answer[0]["generated_text"]
|
||||
|
||||
@@ -25,23 +25,20 @@ from bentoml.io import JSON
|
||||
from bentoml.io import Text
|
||||
|
||||
class Query(BaseModel):
|
||||
industry: str
|
||||
product_name: str
|
||||
keywords: t.List[str]
|
||||
llm_config: t.Dict[str, t.Any]
|
||||
|
||||
industry: str
|
||||
product_name: str
|
||||
keywords: t.List[str]
|
||||
llm_config: t.Dict[str, t.Any]
|
||||
|
||||
def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM:
|
||||
lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
|
||||
lc_llm.runner.download_model()
|
||||
return lc_llm
|
||||
|
||||
lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
|
||||
lc_llm.runner.download_model()
|
||||
return lc_llm
|
||||
|
||||
llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b")
|
||||
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["industry", "product_name", "keywords"],
|
||||
template="""
|
||||
input_variables=["industry", "product_name", "keywords"], template="""
|
||||
You are a Facebook Ads Copywriter with a strong background in persuasive
|
||||
writing and marketing. You craft compelling copy that appeals to the target
|
||||
audience's emotions and needs, peruading them to take action or make a
|
||||
@@ -59,22 +56,12 @@ chain = LLMChain(llm=llm, prompt=prompt)
|
||||
|
||||
svc = bentoml.Service("fb-ads-copy", runners=[llm.runner])
|
||||
|
||||
|
||||
@svc.on_startup
|
||||
def download(_: bentoml.Context):
|
||||
llm.runner.download_model()
|
||||
|
||||
|
||||
SAMPLE_INPUT = Query(
|
||||
industry="SAAS",
|
||||
product_name="BentoML",
|
||||
keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"],
|
||||
llm_config=llm.runner.config.model_dump(),
|
||||
)
|
||||
llm.runner.download_model()
|
||||
|
||||
SAMPLE_INPUT = Query(industry="SAAS", product_name="BentoML", keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"], llm_config=llm.runner.config.model_dump(),)
|
||||
|
||||
@svc.api(input=JSON.from_sample(sample=SAMPLE_INPUT), output=Text())
|
||||
def generate(query: Query):
|
||||
return chain.run(
|
||||
{"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)}
|
||||
)
|
||||
return chain.run({"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)})
|
||||
|
||||
@@ -22,16 +22,11 @@ from bentoml.io import Text
|
||||
|
||||
SAMPLE_INPUT = "What is the weather in San Francisco?"
|
||||
|
||||
llm = OpenLLM(
|
||||
model_name="dolly-v2",
|
||||
model_id="databricks/dolly-v2-7b",
|
||||
embedded=False,
|
||||
)
|
||||
llm = OpenLLM(model_name="dolly-v2", model_id="databricks/dolly-v2-7b", embedded=False,)
|
||||
tools = load_tools(["serpapi"], llm=llm)
|
||||
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)
|
||||
svc = bentoml.Service("langchain-openllm", runners=[llm.runner])
|
||||
|
||||
|
||||
@svc.api(input=Text.from_sample(sample=SAMPLE_INPUT), output=Text())
|
||||
def chat(input_text: str):
|
||||
return agent.run(input_text)
|
||||
return agent.run(input_text)
|
||||
|
||||
@@ -205,6 +205,7 @@ ignore = [
|
||||
"PLR0915",
|
||||
"PLR2004", # magic value to use constant
|
||||
"E501", # ignore line length violation
|
||||
"E702",
|
||||
"PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs.
|
||||
"D103", # Just missing docstring for magic methods.
|
||||
"D102",
|
||||
@@ -262,6 +263,49 @@ avoid-escape = false
|
||||
]
|
||||
"typings/**/*" = ["D", "F", "E", "PYI002"]
|
||||
|
||||
[tool.yapf]
|
||||
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true
|
||||
ALLOW_MULTILINE_DICTIONARY_KEYS = false
|
||||
ALLOW_MULTILINE_LAMBDAS = false
|
||||
ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS = false
|
||||
ALLOW_SPLIT_BEFORE_DICT_VALUE = false
|
||||
ARITHMETIC_PRECEDENCE_INDICATION = true
|
||||
BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1
|
||||
BLANK_LINES_BETWEEN_TOP_LEVEL_IMPORTS_AND_VARIABLES = 1
|
||||
BLANK_LINE_BEFORE_CLASS_DOCSTRING = false
|
||||
BLANK_LINE_BEFORE_MODULE_DOCSTRING = false
|
||||
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false
|
||||
COALESCE_BRACKETS = true
|
||||
COLUMN_LIMIT = 384
|
||||
CONTINUATION_ALIGN_STYLE = "VALIGN-RIGHT"
|
||||
DEDENT_CLOSING_BRACKETS = true
|
||||
DISABLE_ENDING_COMMA_HEURISTIC = true
|
||||
EACH_DICT_ENTRY_ON_SEPARATE_LINE = false
|
||||
INDENT_BLANK_LINES = false
|
||||
INDENT_CLOSING_BRACKETS = false
|
||||
INDENT_WIDTH = 2
|
||||
JOIN_MULTIPLE_LINES = true
|
||||
NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = true
|
||||
SPACES_AROUND_SUBSCRIPT_COLON = false
|
||||
SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = false
|
||||
SPACE_INSIDE_BRACKETS = false
|
||||
SPLIT_ALL_COMMA_SEPARATED_VALUES = false
|
||||
SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES = false
|
||||
SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED = false
|
||||
SPLIT_BEFORE_BITWISE_OPERATOR = false
|
||||
SPLIT_BEFORE_CLOSING_BRACKET = false
|
||||
SPLIT_BEFORE_DICT_SET_GENERATOR = false
|
||||
SPLIT_BEFORE_DOT = false
|
||||
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = false
|
||||
SPLIT_BEFORE_FIRST_ARGUMENT = false
|
||||
SPLIT_BEFORE_LOGICAL_OPERATOR = false
|
||||
SPLIT_BEFORE_NAMED_ASSIGNS = false
|
||||
SPLIT_COMPLEX_COMPREHENSION = false
|
||||
SPLIT_PENALTY_AFTER_OPENING_BRACKET = 10000
|
||||
SPLIT_PENALTY_BEFORE_IF_EXPR = 10000
|
||||
SPLIT_PENALTY_COMPREHENSION = 3000
|
||||
SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT = 8000
|
||||
|
||||
[tool.coverage.paths]
|
||||
openllm = ["src/openllm", "*/openllm/src/openllm"]
|
||||
[tool.coverage.run]
|
||||
|
||||
@@ -32,244 +32,217 @@ from . import utils as utils
|
||||
from .exceptions import MissingDependencyError
|
||||
|
||||
if utils.DEBUG:
|
||||
utils.set_debug_mode(True)
|
||||
utils.set_quiet_mode(False)
|
||||
logging.basicConfig(level=logging.NOTSET)
|
||||
utils.set_debug_mode(True)
|
||||
utils.set_quiet_mode(False)
|
||||
logging.basicConfig(level=logging.NOTSET)
|
||||
else:
|
||||
# configuration for bitsandbytes before import
|
||||
os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
# The following warnings from bitsandbytes, and probably not that important
|
||||
# for users to see when DEBUG is False
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.")
|
||||
|
||||
# configuration for bitsandbytes before import
|
||||
os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
# The following warnings from bitsandbytes, and probably not that important
|
||||
# for users to see when DEBUG is False
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.")
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
|
||||
"_configuration": ["LLMConfig"],
|
||||
"_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs"],
|
||||
"_generation": ["StopSequenceCriteria", "StopOnTokens"],
|
||||
"_quantisation": ["infer_quantisation_config"],
|
||||
"exceptions": [],
|
||||
"utils": ["infer_auto_class"],
|
||||
"models": [],
|
||||
"client": [],
|
||||
"bundle": [],
|
||||
"playground": [],
|
||||
"testing": [],
|
||||
"serialisation": ["ggml", "transformers"],
|
||||
"cli.entrypoint": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"], "_configuration": ["LLMConfig"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs"], "_generation": ["StopSequenceCriteria", "StopOnTokens"], "_quantisation": ["infer_quantisation_config"], "exceptions": [], "utils": ["infer_auto_class"], "models": [],
|
||||
"client": [], "bundle": [], "playground": [], "testing": [], "serialisation": ["ggml", "transformers"], "cli.entrypoint": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
# NOTE: models
|
||||
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES", ],
|
||||
"models.chatglm": ["ChatGLMConfig"],
|
||||
"models.baichuan": ["BaichuanConfig"],
|
||||
"models.dolly_v2": ["DollyV2Config"],
|
||||
"models.falcon": ["FalconConfig"],
|
||||
"models.flan_t5": ["FlanT5Config"],
|
||||
"models.gpt_neox": ["GPTNeoXConfig"],
|
||||
"models.llama": ["LlamaConfig"],
|
||||
"models.mpt": ["MPTConfig"],
|
||||
"models.opt": ["OPTConfig"],
|
||||
"models.stablelm": ["StableLMConfig"],
|
||||
"models.starcoder": ["StarCoderConfig"],
|
||||
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"],
|
||||
"models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"],
|
||||
}
|
||||
|
||||
# NOTE: torch and cpm_kernels
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
|
||||
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_pt_and_cpm_kernels_objects
|
||||
_import_structure["utils.dummy_pt_and_cpm_kernels_objects"] = [name for name in dir(dummy_pt_and_cpm_kernels_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_pt_and_cpm_kernels_objects
|
||||
_import_structure["utils.dummy_pt_and_cpm_kernels_objects"] = [name for name in dir(dummy_pt_and_cpm_kernels_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.chatglm"].extend(["ChatGLM"])
|
||||
_import_structure["models.baichuan"].extend(["Baichuan"])
|
||||
_import_structure["models.chatglm"].extend(["ChatGLM"])
|
||||
_import_structure["models.baichuan"].extend(["Baichuan"])
|
||||
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
|
||||
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_pt_and_einops_objects
|
||||
_import_structure["utils.dummy_pt_and_einops_objects"] = [name for name in dir(dummy_pt_and_einops_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_pt_and_einops_objects
|
||||
_import_structure["utils.dummy_pt_and_einops_objects"] = [name for name in dir(dummy_pt_and_einops_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.falcon"].extend(["Falcon"])
|
||||
_import_structure["models.falcon"].extend(["Falcon"])
|
||||
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
|
||||
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_pt_and_triton_objects
|
||||
_import_structure["utils.dummy_pt_and_triton_objects"] = [name for name in dir(dummy_pt_and_triton_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_pt_and_triton_objects
|
||||
_import_structure["utils.dummy_pt_and_triton_objects"] = [name for name in dir(dummy_pt_and_triton_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
|
||||
try:
|
||||
if not utils.is_torch_available(): raise MissingDependencyError
|
||||
if not utils.is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_pt_objects
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_pt_objects
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlanT5"])
|
||||
_import_structure["models.dolly_v2"].extend(["DollyV2"])
|
||||
_import_structure["models.starcoder"].extend(["StarCoder"])
|
||||
_import_structure["models.stablelm"].extend(["StableLM"])
|
||||
_import_structure["models.opt"].extend(["OPT"])
|
||||
_import_structure["models.gpt_neox"].extend(["GPTNeoX"])
|
||||
_import_structure["models.llama"].extend(["Llama"])
|
||||
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"])
|
||||
_import_structure["models.flan_t5"].extend(["FlanT5"])
|
||||
_import_structure["models.dolly_v2"].extend(["DollyV2"])
|
||||
_import_structure["models.starcoder"].extend(["StarCoder"])
|
||||
_import_structure["models.stablelm"].extend(["StableLM"])
|
||||
_import_structure["models.opt"].extend(["OPT"])
|
||||
_import_structure["models.gpt_neox"].extend(["GPTNeoX"])
|
||||
_import_structure["models.llama"].extend(["Llama"])
|
||||
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"])
|
||||
|
||||
try:
|
||||
if not utils.is_vllm_available(): raise MissingDependencyError
|
||||
if not utils.is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_vllm_objects
|
||||
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(dummy_vllm_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_vllm_objects
|
||||
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(dummy_vllm_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.llama"].extend(["VLLMLlama"])
|
||||
_import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
|
||||
_import_structure["models.llama"].extend(["VLLMLlama"])
|
||||
_import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
|
||||
|
||||
try:
|
||||
if not utils.is_flax_available(): raise MissingDependencyError
|
||||
if not utils.is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_flax_objects
|
||||
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_flax_objects
|
||||
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
|
||||
_import_structure["models.opt"].extend(["FlaxOPT"])
|
||||
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
|
||||
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
|
||||
_import_structure["models.opt"].extend(["FlaxOPT"])
|
||||
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
|
||||
|
||||
try:
|
||||
if not utils.is_tf_available(): raise MissingDependencyError
|
||||
if not utils.is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils import dummy_tf_objects
|
||||
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
|
||||
from .utils import dummy_tf_objects
|
||||
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
|
||||
_import_structure["models.opt"].extend(["TFOPT"])
|
||||
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
|
||||
_import_structure["models.opt"].extend(["TFOPT"])
|
||||
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
|
||||
# declaration for OpenLLM-related modules
|
||||
if t.TYPE_CHECKING:
|
||||
from . import bundle as bundle
|
||||
from . import cli as cli
|
||||
from . import client as client
|
||||
from . import exceptions as exceptions
|
||||
from . import models as models
|
||||
from . import playground as playground
|
||||
from . import serialisation as serialisation
|
||||
from . import testing as testing
|
||||
from . import bundle as bundle
|
||||
from . import cli as cli
|
||||
from . import client as client
|
||||
from . import exceptions as exceptions
|
||||
from . import models as models
|
||||
from . import playground as playground
|
||||
from . import serialisation as serialisation
|
||||
from . import testing as testing
|
||||
|
||||
# Specific types import
|
||||
from ._configuration import LLMConfig as LLMConfig
|
||||
from ._generation import StopOnTokens as StopOnTokens
|
||||
from ._generation import StopSequenceCriteria as StopSequenceCriteria
|
||||
from ._llm import LLM as LLM
|
||||
from ._llm import LLMRunnable as LLMRunnable
|
||||
from ._llm import LLMRunner as LLMRunner
|
||||
from ._llm import Runner as Runner
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from ._schema import EmbeddingsOutput as EmbeddingsOutput
|
||||
from ._schema import GenerationInput as GenerationInput
|
||||
from ._schema import GenerationOutput as GenerationOutput
|
||||
from ._schema import MetadataOutput as MetadataOutput
|
||||
from ._schema import unmarshal_vllm_outputs as unmarshal_vllm_outputs
|
||||
from .cli.entrypoint import build as build
|
||||
from .cli.entrypoint import import_model as import_model
|
||||
from .cli.entrypoint import list_models as list_models
|
||||
from .cli.entrypoint import start as start
|
||||
from .cli.entrypoint import start_grpc as start_grpc
|
||||
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
|
||||
from .models.auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
from .models.auto import AutoConfig as AutoConfig
|
||||
from .models.baichuan import BaichuanConfig as BaichuanConfig
|
||||
from .models.chatglm import ChatGLMConfig as ChatGLMConfig
|
||||
from .models.dolly_v2 import DollyV2Config as DollyV2Config
|
||||
from .models.falcon import FalconConfig as FalconConfig
|
||||
from .models.flan_t5 import FlanT5Config as FlanT5Config
|
||||
from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig
|
||||
from .models.llama import LlamaConfig as LlamaConfig
|
||||
from .models.mpt import MPTConfig as MPTConfig
|
||||
from .models.opt import OPTConfig as OPTConfig
|
||||
from .models.stablelm import StableLMConfig as StableLMConfig
|
||||
from .models.starcoder import StarCoderConfig as StarCoderConfig
|
||||
from .serialisation import ggml as ggml
|
||||
from .serialisation import transformers as transformers
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
# Specific types import
|
||||
from ._configuration import LLMConfig as LLMConfig
|
||||
from ._generation import StopOnTokens as StopOnTokens
|
||||
from ._generation import StopSequenceCriteria as StopSequenceCriteria
|
||||
from ._llm import LLM as LLM
|
||||
from ._llm import LLMRunnable as LLMRunnable
|
||||
from ._llm import LLMRunner as LLMRunner
|
||||
from ._llm import Runner as Runner
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from ._schema import EmbeddingsOutput as EmbeddingsOutput
|
||||
from ._schema import GenerationInput as GenerationInput
|
||||
from ._schema import GenerationOutput as GenerationOutput
|
||||
from ._schema import MetadataOutput as MetadataOutput
|
||||
from ._schema import unmarshal_vllm_outputs as unmarshal_vllm_outputs
|
||||
from .cli.entrypoint import build as build
|
||||
from .cli.entrypoint import import_model as import_model
|
||||
from .cli.entrypoint import list_models as list_models
|
||||
from .cli.entrypoint import start as start
|
||||
from .cli.entrypoint import start_grpc as start_grpc
|
||||
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
|
||||
from .models.auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
from .models.auto import AutoConfig as AutoConfig
|
||||
from .models.baichuan import BaichuanConfig as BaichuanConfig
|
||||
from .models.chatglm import ChatGLMConfig as ChatGLMConfig
|
||||
from .models.dolly_v2 import DollyV2Config as DollyV2Config
|
||||
from .models.falcon import FalconConfig as FalconConfig
|
||||
from .models.flan_t5 import FlanT5Config as FlanT5Config
|
||||
from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig
|
||||
from .models.llama import LlamaConfig as LlamaConfig
|
||||
from .models.mpt import MPTConfig as MPTConfig
|
||||
from .models.opt import OPTConfig as OPTConfig
|
||||
from .models.stablelm import StableLMConfig as StableLMConfig
|
||||
from .models.starcoder import StarCoderConfig as StarCoderConfig
|
||||
from .serialisation import ggml as ggml
|
||||
from .serialisation import transformers as transformers
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
# NOTE: torch and cpm_kernels
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_and_cpm_kernels_objects import *
|
||||
else:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
|
||||
# NOTE: torch and einops
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_and_einops_objects import *
|
||||
else:
|
||||
from .models.falcon import Falcon as Falcon
|
||||
# NOTE: torch and triton
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_and_triton_objects import *
|
||||
else:
|
||||
from .models.mpt import MPT as MPT
|
||||
try:
|
||||
if not utils.is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_MAPPING as MODEL_MAPPING
|
||||
from .models.auto import AutoLLM as AutoLLM
|
||||
from .models.dolly_v2 import DollyV2 as DollyV2
|
||||
from .models.flan_t5 import FlanT5 as FlanT5
|
||||
from .models.gpt_neox import GPTNeoX as GPTNeoX
|
||||
from .models.llama import Llama as Llama
|
||||
from .models.opt import OPT as OPT
|
||||
from .models.stablelm import StableLM as StableLM
|
||||
from .models.starcoder import StarCoder as StarCoder
|
||||
try:
|
||||
if not utils.is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_vllm_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING
|
||||
from .models.auto import AutoVLLM as AutoVLLM
|
||||
from .models.llama import VLLMLlama as VLLMLlama
|
||||
from .models.opt import VLLMOPT as VLLMOPT
|
||||
try:
|
||||
if not utils.is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_flax_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
|
||||
from .models.auto import AutoFlaxLLM as AutoFlaxLLM
|
||||
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
from .models.opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not utils.is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_tf_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
|
||||
from .models.auto import AutoTFLLM as AutoTFLLM
|
||||
from .models.flan_t5 import TFFlanT5 as TFFlanT5
|
||||
from .models.opt import TFOPT as TFOPT
|
||||
|
||||
# NOTE: torch and cpm_kernels
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_and_cpm_kernels_objects import *
|
||||
else:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
|
||||
|
||||
# NOTE: torch and einops
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_einops_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_and_einops_objects import *
|
||||
else:
|
||||
from .models.falcon import Falcon as Falcon
|
||||
|
||||
# NOTE: torch and triton
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_triton_available()): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_and_triton_objects import *
|
||||
else:
|
||||
from .models.mpt import MPT as MPT
|
||||
|
||||
try:
|
||||
if not utils.is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_pt_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_MAPPING as MODEL_MAPPING
|
||||
from .models.auto import AutoLLM as AutoLLM
|
||||
from .models.dolly_v2 import DollyV2 as DollyV2
|
||||
from .models.flan_t5 import FlanT5 as FlanT5
|
||||
from .models.gpt_neox import GPTNeoX as GPTNeoX
|
||||
from .models.llama import Llama as Llama
|
||||
from .models.opt import OPT as OPT
|
||||
from .models.stablelm import StableLM as StableLM
|
||||
from .models.starcoder import StarCoder as StarCoder
|
||||
|
||||
try:
|
||||
if not utils.is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_vllm_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING
|
||||
from .models.auto import AutoVLLM as AutoVLLM
|
||||
from .models.llama import VLLMLlama as VLLMLlama
|
||||
from .models.opt import VLLMOPT as VLLMOPT
|
||||
|
||||
try:
|
||||
if not utils.is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_flax_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
|
||||
from .models.auto import AutoFlaxLLM as AutoFlaxLLM
|
||||
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
from .models.opt import FlaxOPT as FlaxOPT
|
||||
|
||||
try:
|
||||
if not utils.is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
from .utils.dummy_tf_objects import *
|
||||
else:
|
||||
from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
|
||||
from .models.auto import AutoTFLLM as AutoTFLLM
|
||||
from .models.flan_t5 import TFFlanT5 as TFFlanT5
|
||||
from .models.opt import TFOPT as TFOPT
|
||||
|
||||
else: sys.modules[__name__] = utils.LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, doc=__doc__,
|
||||
extra_objects={
|
||||
else:
|
||||
sys.modules[__name__] = utils.LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"], _import_structure, module_spec=__spec__, doc=__doc__, extra_objects={
|
||||
# The below is a special mapping that allows openllm to be used as a dictionary.
|
||||
# This is purely for convenience sake, and should not be used in performance critcal
|
||||
# code. This is also not considered as a public API.
|
||||
"__openllm_special__": {"flax": "AutoFlaxLLM", "tf": "AutoTFLLM", "pt": "AutoLLM", "vllm": "AutoVLLM"},
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@@ -20,5 +20,5 @@ To start any OpenLLM model:
|
||||
openllm start <model_name> --options ...
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
from openllm.cli.entrypoint import cli
|
||||
cli()
|
||||
from openllm.cli.entrypoint import cli
|
||||
cli()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generation utilities to be reused throughout."""
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
@@ -19,29 +18,12 @@ import typing as t
|
||||
import transformers
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
import torch
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
"""This class used to stop generation when a seq of tokens are met.
|
||||
|
||||
Args:
|
||||
stop_sequences: `str` or `list[str]` of the sequence (list of sequences) on which to stop execution.
|
||||
tokenizer: Tokenizer to be used to decode the model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer):
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
self.stop_sequences = stop_sequences
|
||||
self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **attrs: t.Any) -> bool:
|
||||
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
|
||||
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
|
||||
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer):
|
||||
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
|
||||
self.stop_sequences,self.tokenizer = stop_sequences, tokenizer
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **attrs: t.Any) -> bool: return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool:
|
||||
stop_ids = {50278, 50279, 50277, 1, 0}
|
||||
return t.cast(int, input_ids[0][-1]) in stop_ids
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool: return t.cast(int, input_ids[0][-1]) in {50278, 50279, 50277, 1, 0}
|
||||
|
||||
1829
src/openllm/_llm.py
1829
src/openllm/_llm.py
File diff suppressed because it is too large
Load Diff
@@ -16,33 +16,25 @@ import string
|
||||
import typing as t
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError("Positional arguments are not supported")
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
"""Check if extra params is passed."""
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f"Extra params passed: {extras}")
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
"""Extract template variables from a template string."""
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0: raise ValueError("Positional arguments are not supported")
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras: raise KeyError(f"Extra params passed: {extras}")
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]: return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str:
|
||||
# Currently, all default prompt will always have `instruction` key.
|
||||
if not use_prompt_template: return prompt
|
||||
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
try: return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
|
||||
# Currently, all default prompt will always have `instruction` key.
|
||||
if not use_prompt_template: return prompt
|
||||
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
try:
|
||||
return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
|
||||
|
||||
@@ -25,119 +25,78 @@ from .utils import pkg
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import auto_gptq as autogptq
|
||||
import torch
|
||||
import auto_gptq as autogptq
|
||||
import torch
|
||||
|
||||
import openllm
|
||||
import transformers
|
||||
import openllm
|
||||
import transformers
|
||||
|
||||
from ._types import DictStrAny
|
||||
from ._types import DictStrAny
|
||||
else:
|
||||
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QuantiseMode = t.Literal["int8", "int4", "gptq"]
|
||||
|
||||
|
||||
# fmt: off
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
|
||||
def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any
|
||||
) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
...
|
||||
def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]: ...
|
||||
# fmt: on
|
||||
def infer_quantisation_config(cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
|
||||
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
|
||||
|
||||
autogptq_attrs: DictStrAny = {"bits": attrs.pop("gptq_bits", 4), "group_size": attrs.pop("gptq_group_size", -1), "damp_percent": attrs.pop("gptq_damp_percent", 0.01), "desc_act": attrs.pop("gptq_desc_act", True), "sym": attrs.pop("gptq_sym", True), "true_sequential": attrs.pop("gptq_true_sequential", True),}
|
||||
|
||||
def infer_quantisation_config(
|
||||
cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
|
||||
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
|
||||
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
|
||||
if int8_skip_modules is None: int8_skip_modules = []
|
||||
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append("lm_head")
|
||||
return transformers.BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload, llm_int8_threshhold=int8_threshold, llm_int8_skip_modules=int8_skip_modules, llm_int8_has_fp16_weight=int8_has_fp16_weight,)
|
||||
|
||||
autogptq_attrs: DictStrAny = {
|
||||
"bits": attrs.pop("gptq_bits", 4),
|
||||
"group_size": attrs.pop("gptq_group_size", -1),
|
||||
"damp_percent": attrs.pop("gptq_damp_percent", 0.01),
|
||||
"desc_act": attrs.pop("gptq_desc_act", True),
|
||||
"sym": attrs.pop("gptq_sym", True),
|
||||
"true_sequential": attrs.pop("gptq_true_sequential", True),
|
||||
}
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
|
||||
int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
|
||||
int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)
|
||||
|
||||
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
|
||||
if int8_skip_modules is None:
|
||||
int8_skip_modules = []
|
||||
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append("lm_head")
|
||||
return transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
|
||||
int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
|
||||
int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)
|
||||
|
||||
# NOTE: Quantization setup
|
||||
# quantize is a openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError(
|
||||
"Quantization requires bitsandbytes to be installed. Make "
|
||||
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
if quantise == "int8":
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "int4":
|
||||
if is_transformers_supports_kbit():
|
||||
quantisation_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support "
|
||||
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
|
||||
"make sure to install the latest version of transformers either via PyPI or "
|
||||
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
|
||||
pkg.pkg_version_info("transformers"),
|
||||
)
|
||||
logger.warning("OpenLLM will fallback to 8-bit quantization.")
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "gptq":
|
||||
if not is_autogptq_available():
|
||||
logger.warning(
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment)."
|
||||
" Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback "
|
||||
"to int8 with bitsandbytes."
|
||||
)
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
else:
|
||||
quantisation_config = autogptq.BaseQuantizeConfig(**autogptq_attrs)
|
||||
# NOTE: Quantization setup
|
||||
# quantize is a openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available(): raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantise == "int8": quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "int4":
|
||||
if is_transformers_supports_kbit(): quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=int4_compute_dtype, bnb_4bit_quant_type=int4_quant_type, bnb_4bit_use_double_quant=int4_use_double_quant,)
|
||||
else:
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")
|
||||
|
||||
return quantisation_config, attrs
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support "
|
||||
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
|
||||
"make sure to install the latest version of transformers either via PyPI or "
|
||||
"from git source: 'pip install git+https://github.com/huggingface/transformers'.", pkg.pkg_version_info("transformers"),
|
||||
)
|
||||
logger.warning("OpenLLM will fallback to 8-bit quantization.")
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "gptq":
|
||||
if not is_autogptq_available():
|
||||
logger.warning("'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment)."
|
||||
" Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback "
|
||||
"to int8 with bitsandbytes.")
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
else:
|
||||
quantisation_config = autogptq.BaseQuantizeConfig(**autogptq_attrs)
|
||||
else:
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")
|
||||
return quantisation_config, attrs
|
||||
|
||||
@@ -27,109 +27,71 @@ from .utils import bentoml_cattr
|
||||
from .utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm
|
||||
import vllm
|
||||
|
||||
from ._types import DictStrAny
|
||||
from ._types import DictStrAny
|
||||
else:
|
||||
DictStrAny = dict
|
||||
vllm = LazyLoader("vllm", globals(), "vllm")
|
||||
|
||||
DictStrAny = dict
|
||||
vllm = LazyLoader("vllm", globals(), "vllm")
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationInput:
|
||||
prompt: str
|
||||
"""The prompt to be sent to system."""
|
||||
prompt: str
|
||||
"""The prompt to be sent to system."""
|
||||
llm_config: LLMConfig
|
||||
"""A mapping of given LLM configuration values for given system."""
|
||||
@staticmethod
|
||||
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
|
||||
if isinstance(data, LLMConfig): return data
|
||||
elif LazyType(DictStrAny).isinstance(data):
|
||||
if cls is None: raise ValueError("'cls' must pass if given data is a dictionary.")
|
||||
return cls(**data)
|
||||
else:
|
||||
raise RuntimeError(f"Type {type(data)} is not yet supported.")
|
||||
|
||||
llm_config: LLMConfig
|
||||
"""A mapping of given LLM configuration values for given system."""
|
||||
|
||||
@staticmethod
|
||||
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
|
||||
if isinstance(data, LLMConfig):
|
||||
return data
|
||||
elif LazyType(DictStrAny).isinstance(data):
|
||||
if cls is None:
|
||||
raise ValueError("'cls' must pass if given data is a dictionary.")
|
||||
return cls(**data)
|
||||
else:
|
||||
raise RuntimeError(f"Type {type(data)} is not yet supported.")
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]:
|
||||
from .models.auto import AutoConfig
|
||||
|
||||
llm_config = AutoConfig.for_model(model_name, **attrs)
|
||||
return attr.make_class(
|
||||
inflection.camelize(llm_config["model_name"]) + "GenerationInput",
|
||||
attrs={
|
||||
"prompt": attr.field(type=str),
|
||||
"llm_config": attr.field(
|
||||
type=llm_config.__class__,
|
||||
default=llm_config,
|
||||
converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def model_dump(self) -> dict[str, t.Any]:
|
||||
return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True)}
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]:
|
||||
from .models.auto import AutoConfig
|
||||
llm_config = AutoConfig.for_model(model_name, **attrs)
|
||||
return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__))})
|
||||
|
||||
def model_dump(self) -> dict[str, t.Any]:
|
||||
return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True)}
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationOutput:
|
||||
responses: t.List[t.Any]
|
||||
"""A list of responses from the system."""
|
||||
responses: t.List[t.Any]
|
||||
"""A list of responses from the system."""
|
||||
configuration: t.Dict[str, t.Any]
|
||||
"""A mapping of configuration values for given system."""
|
||||
@property
|
||||
def marshaled_config(self) -> GenerationConfig:
|
||||
return bentoml_cattr.structure(self.configuration, GenerationConfig)
|
||||
|
||||
configuration: t.Dict[str, t.Any]
|
||||
"""A mapping of configuration values for given system."""
|
||||
|
||||
@property
|
||||
def marshaled_config(self) -> GenerationConfig:
|
||||
return bentoml_cattr.structure(self.configuration, GenerationConfig)
|
||||
|
||||
@property
|
||||
def unmarshaled(self) -> dict[str, t.Any]:
|
||||
return bentoml_cattr.unstructure(self)
|
||||
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if hasattr(self, key): return getattr(self, key)
|
||||
elif key in self.configuration: return self.configuration[key]
|
||||
else: raise KeyError(key)
|
||||
@property
|
||||
def unmarshaled(self) -> dict[str, t.Any]:
|
||||
return bentoml_cattr.unstructure(self)
|
||||
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if hasattr(self, key): return getattr(self, key)
|
||||
elif key in self.configuration: return self.configuration[key]
|
||||
else: raise KeyError(key)
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class MetadataOutput:
|
||||
model_id: str
|
||||
timeout: int
|
||||
model_name: str
|
||||
framework: str
|
||||
configuration: str
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
|
||||
model_id: str
|
||||
timeout: int
|
||||
model_name: str
|
||||
framework: str
|
||||
configuration: str
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class EmbeddingsOutput:
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
@requires_dependencies("vllm", extra="vllm")
|
||||
def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> DictStrAny:
|
||||
return dict(
|
||||
request_id=request_output.request_id,
|
||||
prompt=request_output.prompt,
|
||||
finished=request_output.finished,
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
outputs=[
|
||||
dict(
|
||||
index=it.index,
|
||||
text=it.text,
|
||||
token_ids=it.token_ids,
|
||||
cumulative_logprob=it.cumulative_logprob,
|
||||
logprobs=it.logprobs,
|
||||
finish_reason=it.finish_reason,
|
||||
)
|
||||
for it in request_output.outputs
|
||||
],
|
||||
)
|
||||
return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs],)
|
||||
|
||||
@@ -25,8 +25,8 @@ from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
if t.TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
# The following warnings from bitsandbytes, and probably not that important for users to see
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
@@ -36,37 +36,59 @@ adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map))
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner])
|
||||
|
||||
@svc.api(input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}), route="/v1/generate")
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, **config)
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, **config)
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
|
||||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent}), route="/v1/metadata")
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput: return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,)
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,)
|
||||
|
||||
if runner.supports_embeddings:
|
||||
@svc.api(input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings")
|
||||
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
responses = await runner.embeddings.async_run(phrases)
|
||||
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
|
||||
|
||||
@svc.api(
|
||||
input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({
|
||||
"embeddings": [
|
||||
0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918,
|
||||
0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076
|
||||
], "num_tokens": 20
|
||||
}), route="/v1/embeddings"
|
||||
)
|
||||
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
responses = await runner.embeddings.async_run(phrases)
|
||||
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
|
||||
|
||||
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
inputs: str
|
||||
parameters: t.Dict[str, t.Any]
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
json_str = await request.body()
|
||||
try: input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), HfAgentInput)
|
||||
except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
|
||||
stop = input_data.parameters.pop("stop", ["\n"])
|
||||
try: return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
except NotImplementedError: return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
|
||||
svc.mount_asgi_app(hf_app, path="/hf")
|
||||
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
inputs: str
|
||||
parameters: t.Dict[str, t.Any]
|
||||
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
json_str = await request.body()
|
||||
try:
|
||||
input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), HfAgentInput)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
|
||||
stop = input_data.parameters.pop("stop", ["\n"])
|
||||
try:
|
||||
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
except NotImplementedError:
|
||||
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
|
||||
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
|
||||
svc.mount_asgi_app(hf_app, path="/hf")
|
||||
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
|
||||
res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]})
|
||||
return JSONResponse(res, status_code=200)
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
|
||||
res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]})
|
||||
return JSONResponse(res, status_code=200)
|
||||
|
||||
adapters_app_v1 = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])])
|
||||
svc.mount_asgi_app(adapters_app_v1, path="/v1")
|
||||
|
||||
@@ -35,223 +35,220 @@ from .utils import LazyType
|
||||
from .utils import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
ListIntStr = list[int | str]
|
||||
class DynResource(bentoml.Resource[t.List[str]], resource_id=""):
|
||||
resource_id: t.ClassVar[str]
|
||||
ListIntStr = list[int | str]
|
||||
|
||||
class DynResource(bentoml.Resource[t.List[str]], resource_id=""):
|
||||
resource_id: t.ClassVar[str]
|
||||
|
||||
else:
|
||||
DynResource = bentoml.Resource[t.List[str]]
|
||||
ListIntStr = list
|
||||
DynResource = bentoml.Resource[t.List[str]]
|
||||
ListIntStr = list
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
from typing_extensions import overload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _strtoul(s: str) -> int:
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
if not s: return -1
|
||||
idx = 0
|
||||
for idx, c in enumerate(s):
|
||||
if not (c.isdigit() or (idx == 0 and c in "+-")): break
|
||||
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
|
||||
# NOTE: idx will be set via enumerate
|
||||
return int(s[:idx]) if idx > 0 else -1
|
||||
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
if not s: return -1
|
||||
idx = 0
|
||||
for idx, c in enumerate(s):
|
||||
if not (c.isdigit() or (idx == 0 and c in "+-")): break
|
||||
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
|
||||
# NOTE: idx will be set via enumerate
|
||||
return int(s[:idx]) if idx > 0 else -1
|
||||
|
||||
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
|
||||
rcs: list[str] = []
|
||||
for elem in lst.split(","):
|
||||
# Repeated id results in empty set
|
||||
if elem in rcs: return []
|
||||
# Anything other but prefix is ignored
|
||||
if not elem.startswith(prefix): break
|
||||
rcs.append(elem)
|
||||
return rcs
|
||||
|
||||
rcs: list[str] = []
|
||||
for elem in lst.split(","):
|
||||
# Repeated id results in empty set
|
||||
if elem in rcs: return []
|
||||
# Anything other but prefix is ignored
|
||||
if not elem.startswith(prefix): break
|
||||
rcs.append(elem)
|
||||
return rcs
|
||||
|
||||
_STACK_LEVEL = 3
|
||||
|
||||
@overload
|
||||
def _parse_visible_devices(default_var: str | None = ..., respect_env: t.Literal[True] = True) -> list[str] | None: ...
|
||||
def _parse_visible_devices(default_var: str | None = ..., respect_env: t.Literal[True] = True) -> list[str] | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _parse_visible_devices(default_var: str = ..., respect_env: t.Literal[False] = ...) -> list[str]: ...
|
||||
def _parse_visible_devices(default_var: str = ..., respect_env: t.Literal[False] = ...) -> list[str]:
|
||||
...
|
||||
|
||||
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
if respect_env:
|
||||
spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var)
|
||||
if not spec: return None
|
||||
else:
|
||||
if default_var is None: raise ValueError("spec is required to be not None when parsing spec.")
|
||||
spec = default_var
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
if respect_env:
|
||||
spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var)
|
||||
if not spec: return None
|
||||
else:
|
||||
if default_var is None: raise ValueError("spec is required to be not None when parsing spec.")
|
||||
spec = default_var
|
||||
|
||||
if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-")
|
||||
if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-")
|
||||
|
||||
# XXX: We to somehow handle cases such as '100m'
|
||||
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
||||
# which makes `1gpu2,2ampere` is equivalent to `1,2`
|
||||
rc: list[int] = []
|
||||
for el in spec.split(","):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
# Negative value aborts the sequence
|
||||
if x < 0: break
|
||||
rc.append(x)
|
||||
return [str(i) for i in rc]
|
||||
if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-")
|
||||
if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-")
|
||||
|
||||
# XXX: We to somehow handle cases such as '100m'
|
||||
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
||||
# which makes `1gpu2,2ampere` is equivalent to `1,2`
|
||||
rc: list[int] = []
|
||||
for el in spec.split(","):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
# Negative value aborts the sequence
|
||||
if x < 0: break
|
||||
rc.append(x)
|
||||
return [str(i) for i in rc]
|
||||
|
||||
def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
|
||||
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
|
||||
|
||||
It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES.
|
||||
"""
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
if not psutil.LINUX:
|
||||
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
|
||||
return []
|
||||
It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES.
|
||||
"""
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
if not psutil.LINUX:
|
||||
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
|
||||
return []
|
||||
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
sys.path.append("/opt/rocm/libexec/rocm_smi")
|
||||
try:
|
||||
from ctypes import byref
|
||||
from ctypes import c_uint32
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
sys.path.append("/opt/rocm/libexec/rocm_smi")
|
||||
try:
|
||||
from ctypes import byref
|
||||
from ctypes import c_uint32
|
||||
|
||||
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
|
||||
from rsmiBindings import rocmsmi
|
||||
from rsmiBindings import rsmi_status_t
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError): return []
|
||||
finally: sys.path.remove("/opt/rocm/libexec/rocm_smi")
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
except (ImportError, RuntimeError, AttributeError): return []
|
||||
return visible_devices
|
||||
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
|
||||
from rsmiBindings import rocmsmi
|
||||
from rsmiBindings import rsmi_status_t
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
return []
|
||||
finally:
|
||||
sys.path.remove("/opt/rocm/libexec/rocm_smi")
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
except (ImportError, RuntimeError, AttributeError):
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ...
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: ListIntStr) -> list[str]: ...
|
||||
def _from_spec(cls: type[DynResource], spec: ListIntStr) -> list[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ...
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]:
|
||||
...
|
||||
|
||||
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
|
||||
|
||||
The parser behaves similar to how PyTorch handles CUDA_VISIBLE_DEVICES. This means within
|
||||
BentoML's resource configuration, its behaviour is similar to CUDA_VISIBLE_DEVICES.
|
||||
"""
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError("Spec cannot be < -1.")
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif LazyType(ListIntStr).isinstance(spec): return [str(x) for x in spec]
|
||||
else: raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
"""Shared mixin implementation for OpenLLM's NVIDIA and AMD resource implementation.
|
||||
|
||||
The parser behaves similar to how PyTorch handles CUDA_VISIBLE_DEVICES. This means within
|
||||
BentoML's resource configuration, its behaviour is similar to CUDA_VISIBLE_DEVICES.
|
||||
"""
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError("Spec cannot be < -1.")
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif LazyType(ListIntStr).isinstance(spec):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
"""Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed."""
|
||||
from ctypes import CDLL
|
||||
from ctypes import byref
|
||||
from ctypes import c_int
|
||||
from ctypes import c_void_p
|
||||
from ctypes import create_string_buffer
|
||||
"""Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed."""
|
||||
from ctypes import CDLL
|
||||
from ctypes import byref
|
||||
from ctypes import c_int
|
||||
from ctypes import c_void_p
|
||||
from ctypes import create_string_buffer
|
||||
|
||||
try: nvml_h = CDLL("libnvidia-ml.so.1")
|
||||
except Exception:
|
||||
warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
try:
|
||||
nvml_h = CDLL("libnvidia-ml.so.1")
|
||||
except Exception:
|
||||
warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
|
||||
rc = nvml_h.nvmlInit()
|
||||
rc = nvml_h.nvmlInit()
|
||||
if rc != 0:
|
||||
warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
dev_count = c_int(-1)
|
||||
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
||||
if rc != 0:
|
||||
warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids: list[str] = []
|
||||
for idx in range(dev_count.value):
|
||||
dev_id = c_void_p()
|
||||
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
||||
if rc != 0:
|
||||
warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
dev_count = c_int(-1)
|
||||
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
||||
warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
buf_len = 96
|
||||
buf = create_string_buffer(buf_len)
|
||||
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
||||
if rc != 0:
|
||||
warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids: list[str] = []
|
||||
for idx in range(dev_count.value):
|
||||
dev_id = c_void_p()
|
||||
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
||||
if rc != 0:
|
||||
warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
buf_len = 96
|
||||
buf = create_string_buffer(buf_len)
|
||||
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
||||
if rc != 0:
|
||||
warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids.append(buf.raw.decode("ascii").strip("\0"))
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids.append(buf.raw.decode("ascii").strip("\0"))
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.")
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.")
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("Failed to initialise CUDA runtime binding.")
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith("GPU-") or el.startswith("MIG-"):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError("Failed to parse available GPUs UUID")
|
||||
if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})")
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}")
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("Failed to initialise CUDA runtime binding.")
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith("GPU-") or el.startswith("MIG-"):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError("Failed to parse available GPUs UUID")
|
||||
if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})")
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}")
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
return types.new_class(
|
||||
name,
|
||||
(DynResource, ReprMixin),
|
||||
{"resource_id": resource_kind},
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
"resource_id": resource_kind,
|
||||
"from_spec": classmethod(_from_spec),
|
||||
"from_system": classmethod(_from_system),
|
||||
"validate": classmethod(_validate),
|
||||
"__repr_keys__": property(lambda _: {"resource_id"}),
|
||||
"__doc__": inspect.cleandoc(docstring),
|
||||
"__module__": "openllm._strategies",
|
||||
}
|
||||
),
|
||||
)
|
||||
return types.new_class(
|
||||
name, (DynResource, ReprMixin), {"resource_id": resource_kind}, lambda ns: ns.update({"resource_id": resource_kind, "from_spec": classmethod(_from_spec), "from_system": classmethod(_from_system), "validate": classmethod(_validate), "__repr_keys__": property(lambda _: {"resource_id"}), "__doc__": inspect.cleandoc(docstring), "__module__": "openllm._strategies",}),
|
||||
)
|
||||
|
||||
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal :facepalm:
|
||||
_TPU_RESOURCE: t.Literal["cloud-tpus.google.com/v2"] = "cloud-tpus.google.com/v2"
|
||||
@@ -259,159 +256,146 @@ _AMD_GPU_RESOURCE: t.Literal["amd.com/gpu"] = "amd.com/gpu"
|
||||
_NVIDIA_GPU_RESOURCE: t.Literal["nvidia.com/gpu"] = "nvidia.com/gpu"
|
||||
_CPU_RESOURCE: t.Literal["cpu"] = "cpu"
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
"NvidiaGpuResource",
|
||||
_NVIDIA_GPU_RESOURCE,
|
||||
"""NVIDIA GPU resource.
|
||||
NvidiaGpuResource = _make_resource_class("NvidiaGpuResource", _NVIDIA_GPU_RESOURCE, """NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
"AmdGpuResource",
|
||||
_AMD_GPU_RESOURCE,
|
||||
"""AMD GPU resource.
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",)
|
||||
AmdGpuResource = _make_resource_class("AmdGpuResource", _AMD_GPU_RESOURCE, """AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",
|
||||
)
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",)
|
||||
|
||||
LiteralResourceSpec = t.Literal["cloud-tpus.google.com/v2", "amd.com/gpu", "nvidia.com/gpu", "cpu"]
|
||||
|
||||
# convenient mapping
|
||||
def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec:
|
||||
if name == "tpu": return _TPU_RESOURCE
|
||||
elif name == "amd": return _AMD_GPU_RESOURCE
|
||||
elif name == "nvidia": return _NVIDIA_GPU_RESOURCE
|
||||
elif name == "cpu": return _CPU_RESOURCE
|
||||
else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']")
|
||||
def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec:
|
||||
if name == "tpu": return _TPU_RESOURCE
|
||||
elif name == "amd": return _AMD_GPU_RESOURCE
|
||||
elif name == "nvidia": return _NVIDIA_GPU_RESOURCE
|
||||
elif name == "cpu": return _CPU_RESOURCE
|
||||
else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']")
|
||||
|
||||
@functools.lru_cache
|
||||
def available_resource_spec() -> tuple[LiteralResourceSpec, ...]:
|
||||
"""This is a utility function helps to determine the available resources from given running system.
|
||||
"""This is a utility function helps to determine the available resources from given running system.
|
||||
|
||||
It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs.
|
||||
|
||||
TODO: Supports TPUs
|
||||
"""
|
||||
available = ()
|
||||
if len(AmdGpuResource.from_system()) > 0: available += (_AMD_GPU_RESOURCE,)
|
||||
if len(NvidiaGpuResource.from_system()) > 0: available += (_NVIDIA_GPU_RESOURCE,)
|
||||
available += (_CPU_RESOURCE,)
|
||||
return t.cast(t.Tuple[LiteralResourceSpec, ...], available)
|
||||
It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs.
|
||||
|
||||
TODO: Supports TPUs
|
||||
"""
|
||||
available = ()
|
||||
if len(AmdGpuResource.from_system()) > 0: available += (_AMD_GPU_RESOURCE,)
|
||||
if len(NvidiaGpuResource.from_system()) > 0: available += (_NVIDIA_GPU_RESOURCE,)
|
||||
available += (_CPU_RESOURCE,)
|
||||
return t.cast(t.Tuple[LiteralResourceSpec, ...], available)
|
||||
|
||||
class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
|
||||
It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU.
|
||||
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
|
||||
for ROCm's support for CUDA_VISIBLE_DEVICES.
|
||||
It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU.
|
||||
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
|
||||
for ROCm's support for CUDA_VISIBLE_DEVICES.
|
||||
|
||||
TODO: Support CloudTPUResource
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float) -> int:
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
|
||||
def _get_gpu_count(typ: list[str] | None, kind: str) -> int | None:
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: return math.ceil(len(typ) * workers_per_resource)
|
||||
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
count = _get_gpu_count(get_resource(resource_request, kind), kind)
|
||||
if count: return count
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
count = _get_gpu_count(get_resource(resource_request, kind, validate=False), kind)
|
||||
if count: return count
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
if cpus is not None and cpus > 0:
|
||||
if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class)
|
||||
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.")
|
||||
return int(workers_per_resource)
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.")
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float) -> int:
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
|
||||
def _get_gpu_count(typ: list[str] | None, kind: str) -> int | None:
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: return math.ceil(len(typ) * workers_per_resource)
|
||||
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
count = _get_gpu_count(get_resource(resource_request, kind), kind)
|
||||
if count: return count
|
||||
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
count = _get_gpu_count(get_resource(resource_request, kind, validate=False), kind)
|
||||
if count: return count
|
||||
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
if cpus is not None and cpus > 0:
|
||||
if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class)
|
||||
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.")
|
||||
return int(workers_per_resource)
|
||||
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.")
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
disabled = cuda_env in ("", "-1")
|
||||
|
||||
environ: dict[str, t.Any] = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
typ = get_resource(resource_request, kind)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
typ = get_resource(resource_request, kind, validate=False)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
if cpus is not None and cpus > 0:
|
||||
environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
thread_count = math.ceil(cpus)
|
||||
for thread_env in THREAD_ENVS: environ[thread_env] = os.getenv(thread_env, str(thread_count))
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
for thread_env in THREAD_ENVS: environ[thread_env] = os.getenv(thread_env, "1")
|
||||
return environ
|
||||
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
disabled = cuda_env in ("", "-1")
|
||||
environ: dict[str, t.Any] = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
typ = get_resource(resource_request, kind)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
typ = get_resource(resource_request, kind, validate=False)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
if cpus is not None and cpus > 0:
|
||||
environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
thread_count = math.ceil(cpus)
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.getenv(thread_env, str(thread_count))
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.getenv(thread_env, "1")
|
||||
return environ
|
||||
return environ
|
||||
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str:
|
||||
# Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.
|
||||
if isinstance(workers_per_resource, float):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to
|
||||
# float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
if len(gpus) < assigned_resource_per_worker:
|
||||
logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1) ]
|
||||
dev = ",".join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}")
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str:
|
||||
# Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.
|
||||
if isinstance(workers_per_resource, float):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to
|
||||
# float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
if len(gpus) < assigned_resource_per_worker:
|
||||
logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index+1)]
|
||||
dev = ",".join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}")
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
|
||||
@@ -19,10 +19,8 @@ It will raises a RuntimeError if this is imported eagerly.
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
|
||||
if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime")
|
||||
|
||||
|
||||
import attr
|
||||
|
||||
import bentoml
|
||||
@@ -31,17 +29,15 @@ from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
|
||||
from ._configuration import AdapterType
|
||||
from ._configuration import LiteralRuntime as LiteralRuntime
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
|
||||
import openllm
|
||||
from openllm._llm import M as _M
|
||||
from openllm._llm import T as _T
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
import peft
|
||||
|
||||
import openllm
|
||||
from openllm._llm import M as _M
|
||||
from openllm._llm import T as _T
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = dict[str, t.Any]
|
||||
@@ -54,72 +50,72 @@ T = t.TypeVar("T")
|
||||
Ts = t.TypeVarTuple("Ts")
|
||||
At = t.TypeVar("At", bound=attr.AttrsInstance)
|
||||
|
||||
|
||||
class PeftAdapterOutput(t.TypedDict):
|
||||
success: bool
|
||||
result: dict[str, peft.PeftConfig]
|
||||
error_msg: str
|
||||
|
||||
success: bool
|
||||
result: dict[str, peft.PeftConfig]
|
||||
error_msg: str
|
||||
|
||||
class LLMEmbeddings(t.TypedDict):
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
class AdaptersTuple(TupleAny):
|
||||
adapter_id: str
|
||||
name: str | None
|
||||
config: DictStrAny
|
||||
|
||||
adapter_id: str
|
||||
name: str | None
|
||||
config: DictStrAny
|
||||
|
||||
AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]]
|
||||
|
||||
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[_M, _T]):
|
||||
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
|
||||
|
||||
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
|
||||
|
||||
class LLMRunner(bentoml.Runner, t.Generic[_M, _T]):
|
||||
__doc__: str
|
||||
__module__: str
|
||||
llm_type: str
|
||||
identifying_params: dict[str, t.Any]
|
||||
llm: openllm.LLM[_M, _T]
|
||||
config: openllm.LLMConfig
|
||||
implementation: LiteralRuntime
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
has_adapters: bool
|
||||
embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
|
||||
def __init__(
|
||||
self,
|
||||
runnable_class: type[LLMRunnable[_M, _T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False,
|
||||
) -> None: ...
|
||||
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
|
||||
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
def download_model(self) -> bentoml.Model: ...
|
||||
@property
|
||||
def peft_adapters(self) -> PeftAdapterOutput: ...
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]: ...
|
||||
__doc__: str
|
||||
__module__: str
|
||||
llm_type: str
|
||||
identifying_params: dict[str, t.Any]
|
||||
llm: openllm.LLM[_M, _T]
|
||||
config: openllm.LLMConfig
|
||||
implementation: LiteralRuntime
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
has_adapters: bool
|
||||
embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
|
||||
|
||||
def __init__(
|
||||
self, runnable_class: type[LLMRunnable[_M, _T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
def embed(self, prompt: str | list[str]) -> LLMEmbeddings:
|
||||
...
|
||||
|
||||
def run(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
def download_model(self) -> bentoml.Model:
|
||||
...
|
||||
|
||||
@property
|
||||
def peft_adapters(self) -> PeftAdapterOutput:
|
||||
...
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
...
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Build-related utilities. Some of these utilities are mainly used for 'openllm.build'.
|
||||
|
||||
These utilities will stay internal, and its API can be changed or updated without backward-compatibility.
|
||||
@@ -23,20 +22,18 @@ import typing as t
|
||||
from . import oci as oci
|
||||
from ..utils import LazyModule
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"],
|
||||
"oci": oci.__all__,
|
||||
}
|
||||
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": oci.__all__,}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from . import _package as _package
|
||||
from ._package import build_editable as build_editable
|
||||
from ._package import construct_docker_options as construct_docker_options
|
||||
from ._package import construct_python_options as construct_python_options
|
||||
from ._package import create_bento as create_bento
|
||||
from .oci import CONTAINER_NAMES as CONTAINER_NAMES
|
||||
from .oci import build_container as build_container
|
||||
from .oci import get_base_container_name as get_base_container_name
|
||||
from .oci import get_base_container_tag as get_base_container_tag
|
||||
from .oci import supported_registries as supported_registries
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from . import _package as _package
|
||||
from ._package import build_editable as build_editable
|
||||
from ._package import construct_docker_options as construct_docker_options
|
||||
from ._package import construct_python_options as construct_python_options
|
||||
from ._package import create_bento as create_bento
|
||||
from .oci import CONTAINER_NAMES as CONTAINER_NAMES
|
||||
from .oci import build_container as build_container
|
||||
from .oci import get_base_container_name as get_base_container_name
|
||||
from .oci import get_base_container_tag as get_base_container_tag
|
||||
from .oci import supported_registries as supported_registries
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -45,200 +45,149 @@ from ..utils import is_torch_available
|
||||
from ..utils import pkg
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
from fs.base import FS
|
||||
|
||||
import openllm
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
import openllm
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
|
||||
from .oci import LiteralContainerRegistry
|
||||
from .oci import LiteralContainerVersionStrategy
|
||||
from .oci import LiteralContainerRegistry
|
||||
from .oci import LiteralContainerVersionStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = "OPENLLM_DEV_BUILD"
|
||||
|
||||
|
||||
def build_editable(path: str) -> str | None:
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
module_location = pkg.source_locations("openllm")
|
||||
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
|
||||
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
|
||||
with IsolatedEnvBuilder() as env:
|
||||
builder = ProjectBuilder(pyproject_path.parent)
|
||||
builder.python_executable = env.executable
|
||||
builder.scripts_dir = env.scripts_dir
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build("wheel", path, config_settings={"--global-option": "--quiet"})
|
||||
raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.")
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
module_location = pkg.source_locations("openllm")
|
||||
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
|
||||
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
|
||||
with IsolatedEnvBuilder() as env:
|
||||
builder = ProjectBuilder(pyproject_path.parent)
|
||||
builder.python_executable = env.executable
|
||||
builder.scripts_dir = env.scripts_dir
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build("wheel", path, config_settings={"--global-option": "--quiet"})
|
||||
raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.")
|
||||
|
||||
def construct_python_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
) -> PythonOptions:
|
||||
packages = ["openllm", "scipy"] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ["openllm[fine-tune]"]
|
||||
# NOTE: add openllm to the default dependencies
|
||||
# if users has openllm custom built wheels, it will still respect
|
||||
# that since bentoml will always install dependencies from requirements.txt
|
||||
# first, then proceed to install everything inside the wheels/ folder.
|
||||
if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies]
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions:
|
||||
packages = ["openllm", "scipy"] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ["openllm[fine-tune]"]
|
||||
# NOTE: add openllm to the default dependencies
|
||||
# if users has openllm custom built wheels, it will still respect
|
||||
# that since bentoml will always install dependencies from requirements.txt
|
||||
# first, then proceed to install everything inside the wheels/ folder.
|
||||
if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies]
|
||||
|
||||
req = llm.config["requirements"]
|
||||
if req is not None: packages.extend(req)
|
||||
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
|
||||
req = llm.config["requirements"]
|
||||
if req is not None: packages.extend(req)
|
||||
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
env = llm.config["env"]
|
||||
framework_envvar = env["framework_value"]
|
||||
if framework_envvar == "flax":
|
||||
if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")])
|
||||
elif framework_envvar == "tf":
|
||||
if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
"tensorflow-gpu",
|
||||
"tf-nightly",
|
||||
"tf-nightly-cpu",
|
||||
"tf-nightly-gpu",
|
||||
"intel-tensorflow",
|
||||
"intel-tensorflow-avx512",
|
||||
"tensorflow-rocm",
|
||||
"tensorflow-macos",
|
||||
)
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for candidate in candidates:
|
||||
try:
|
||||
pkgver = importlib.metadata.version(candidate)
|
||||
if pkgver == candidate: packages.extend(["tensorflow"])
|
||||
else:
|
||||
_tf_version = importlib.metadata.version(candidate)
|
||||
packages.extend([f"tensorflow>={_tf_version}"])
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError: pass
|
||||
else:
|
||||
if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
env = llm.config["env"]
|
||||
framework_envvar = env["framework_value"]
|
||||
if framework_envvar == "flax":
|
||||
if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")])
|
||||
elif framework_envvar == "tf":
|
||||
if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos",)
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for candidate in candidates:
|
||||
try:
|
||||
pkgver = importlib.metadata.version(candidate)
|
||||
if pkgver == candidate: packages.extend(["tensorflow"])
|
||||
else:
|
||||
_tf_version = importlib.metadata.version(candidate)
|
||||
packages.extend([f"tensorflow>={_tf_version}"])
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass
|
||||
else:
|
||||
if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
|
||||
wheels: list[str] = []
|
||||
built_wheels = build_editable(llm_fs.getsyspath("/"))
|
||||
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
|
||||
wheels: list[str] = []
|
||||
built_wheels = build_editable(llm_fs.getsyspath("/"))
|
||||
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
|
||||
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
_: FS,
|
||||
workers_per_resource: int | float,
|
||||
quantize: t.LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
adapter_map: dict[str, str | None] | None,
|
||||
dockerfile_template: str | None,
|
||||
runtime: t.Literal["ggml", "transformers"],
|
||||
serialisation_format: t.Literal["safetensors", "legacy"],
|
||||
container_registry: LiteralContainerRegistry,
|
||||
llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy,
|
||||
) -> DockerOptions:
|
||||
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = [
|
||||
"api_server.traffic.timeout=36000", # NOTE: Currently we hardcode this value
|
||||
f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}',
|
||||
f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
_bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts)
|
||||
env: EnvVarMixin = llm.config["env"]
|
||||
env_dict = {
|
||||
env.framework: env.framework_value,
|
||||
env.config: f"'{llm.config.model_dump_json().decode()}'",
|
||||
"OPENLLM_MODEL": llm.config["model_name"],
|
||||
"OPENLLM_SERIALIZATION": serialisation_format,
|
||||
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
|
||||
"OPENLLM_FAST": str(True),
|
||||
"BENTOML_DEBUG": str(True),
|
||||
"BENTOML_QUIET": str(False),
|
||||
"OPENLLMDEVDEBUG": str(get_debug_mode()),
|
||||
"BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
|
||||
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var
|
||||
}
|
||||
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = [
|
||||
"api_server.traffic.timeout=36000", # NOTE: Currently we hardcode this value
|
||||
f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
_bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts)
|
||||
env: EnvVarMixin = llm.config["env"]
|
||||
env_dict = {
|
||||
env.framework: env.framework_value, env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "OPENLLM_FAST": str(True), "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "OPENLLMDEVDEBUG": str(get_debug_mode()),
|
||||
"BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var
|
||||
}
|
||||
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't accept None value.
|
||||
_env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
# We need to handle None separately here, as env from subprocess doesn't accept None value.
|
||||
_env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
|
||||
if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value)
|
||||
if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value
|
||||
env_dict[_env.runtime] = _env.runtime_value
|
||||
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}",env=env_dict, dockerfile_template=dockerfile_template)
|
||||
if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value)
|
||||
if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value
|
||||
env_dict[_env.runtime] = _env.runtime_value
|
||||
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template)
|
||||
|
||||
@inject
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
workers_per_resource: str | int | float,
|
||||
quantize: t.LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors",
|
||||
container_registry: LiteralContainerRegistry = "ecr",
|
||||
container_version_strategy: LiteralContainerVersionStrategy = "release",
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release", _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
) -> bentoml.Bento:
|
||||
framework_envvar = llm.config["env"]["framework_value"]
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
if isinstance(workers_per_resource, str):
|
||||
if workers_per_resource == "round_robin": workers_per_resource = 1.0
|
||||
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if device_count() == 0 else float(1 / device_count())
|
||||
else:
|
||||
try: workers_per_resource = float(workers_per_resource)
|
||||
except ValueError: raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource)
|
||||
framework_envvar = llm.config["env"]["framework_value"]
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
if isinstance(workers_per_resource, str):
|
||||
if workers_per_resource == "round_robin": workers_per_resource = 1.0
|
||||
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if device_count() == 0 else float(1 / device_count())
|
||||
else:
|
||||
try:
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
except ValueError:
|
||||
raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int):
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
|
||||
logger.info("Building Bento for '%s'", llm.config["start_name"])
|
||||
# add service.py definition to this temporary folder
|
||||
codegen.write_service(llm, adapter_map, llm_fs)
|
||||
logger.info("Building Bento for '%s'", llm.config["start_name"])
|
||||
# add service.py definition to this temporary folder
|
||||
codegen.write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name})
|
||||
build_config = BentoBuildConfig(
|
||||
service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy),
|
||||
models=[llm_spec],
|
||||
)
|
||||
llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name})
|
||||
build_config = BentoBuildConfig(
|
||||
service=f"{llm.config['service_name']}:svc", name=bento_tag.name, labels=labels, description=f"OpenLLM service for {llm.config['start_name']}", include=list(llm_fs.walk.files()), exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"], python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy), models=[llm_spec],
|
||||
)
|
||||
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/"))
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container
|
||||
# built with for BentoLLM.
|
||||
service_fs_path = fs.path.join("src", llm.config["service_name"])
|
||||
service_path = bento._fs.getsyspath(service_fs_path)
|
||||
with open(service_path, "r") as f: service_contents = f.readlines()
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/"))
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container
|
||||
# built with for BentoLLM.
|
||||
service_fs_path = fs.path.join("src", llm.config["service_name"])
|
||||
service_path = bento._fs.getsyspath(service_fs_path)
|
||||
with open(service_path, "r") as f:
|
||||
service_contents = f.readlines()
|
||||
|
||||
for it in service_contents:
|
||||
if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
for it in service_contents:
|
||||
if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
|
||||
script = "".join(service_contents)
|
||||
if DEBUG: logger.info("Generated script:\n%s", script)
|
||||
script = "".join(service_contents)
|
||||
if DEBUG: logger.info("Generated script:\n%s", script)
|
||||
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
# backward arguments. `model_store` is added recently
|
||||
return bento.save(bento_store=_bento_store)
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
# backward arguments. `model_store` is added recently
|
||||
return bento.save(bento_store=_bento_store)
|
||||
|
||||
@@ -42,8 +42,7 @@ LiteralContainerVersionStrategy = t.Literal["release", "nightly", "latest"]
|
||||
# but in the future, we can infer based on git repo and everything to make it more options for users
|
||||
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
|
||||
# NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it.
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm",
|
||||
"ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"}
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm", "ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"}
|
||||
|
||||
_URI = "https://github.com/bentoml/openllm.git"
|
||||
|
||||
@@ -51,56 +50,63 @@ _module_location = pkg.source_locations("openllm")
|
||||
|
||||
@functools.lru_cache
|
||||
@apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg]
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
return _CONTAINER_REGISTRY[reg]
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _git() -> git.cmd.Git: return git.cmd.Git(_URI)
|
||||
def _git() -> git.cmd.Git:
|
||||
return git.cmd.Git(_URI)
|
||||
|
||||
@functools.lru_cache
|
||||
def _nightly_ref() -> tuple[str, str]: return _git().ls_remote(_URI, "main", heads=True).split()
|
||||
def _nightly_ref() -> tuple[str, str]:
|
||||
return _git().ls_remote(_URI, "main", heads=True).split()
|
||||
|
||||
@functools.lru_cache
|
||||
def _stable_ref() -> tuple[str, str]: return max([item.split() for item in _git().ls_remote(_URI, refs=True, tags=True).split("\n")],
|
||||
key = lambda tag: tuple(int(k) for k in tag[-1].replace("refs/tags/v", "").split(".")))
|
||||
def _stable_ref() -> tuple[str, str]:
|
||||
return max([item.split() for item in _git().ls_remote(_URI, refs=True, tags=True).split("\n")], key=lambda tag: tuple(int(k) for k in tag[-1].replace("refs/tags/v", "").split(".")))
|
||||
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy) -> str:
|
||||
if strategy == "release": return _stable_ref()[-1].replace("refs/tags/v", "") # for stable, we can also use latest, but discouraged
|
||||
elif strategy == "latest": return "latest"
|
||||
elif strategy == "nightly": return f"sha-{_nightly_ref()[0][:7]}" # we prefixed with sha-<git_rev_short> (giv_rev[:7])
|
||||
else: raise ValueError(f"Unknown strategy '{strategy}'. Valid strategies are 'release', 'nightly', and 'latest'")
|
||||
if strategy == "release": return _stable_ref()[-1].replace("refs/tags/v", "") # for stable, we can also use latest, but discouraged
|
||||
elif strategy == "latest": return "latest"
|
||||
elif strategy == "nightly": return f"sha-{_nightly_ref()[0][:7]}" # we prefixed with sha-<git_rev_short> (giv_rev[:7])
|
||||
else: raise ValueError(f"Unknown strategy '{strategy}'. Valid strategies are 'release', 'nightly', and 'latest'")
|
||||
|
||||
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
|
||||
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
|
||||
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
|
||||
|
||||
Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
|
||||
"""
|
||||
try:
|
||||
if not _BUILDER.health(): raise Error
|
||||
except (Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
|
||||
if device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
|
||||
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("Make sure to have NVIDIA Container Toolkit setup correctly to compile CUDA kernel in container. See https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html for more information.")
|
||||
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
|
||||
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
tags: dict[str | LiteralContainerRegistry, str]
|
||||
if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
|
||||
else:
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()),
|
||||
push=push, progress="plain" if get_debug_mode() else "auto", quiet=machine)
|
||||
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
|
||||
except Exception as err: raise OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
|
||||
return tags
|
||||
Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
|
||||
"""
|
||||
try:
|
||||
if not _BUILDER.health(): raise Error
|
||||
except (Error, subprocess.CalledProcessError):
|
||||
raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
|
||||
if device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
|
||||
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.")
|
||||
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
|
||||
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
if t.TYPE_CHECKING: tags: dict[str | LiteralContainerRegistry, str]
|
||||
if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
|
||||
else:
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if get_debug_mode() else "auto", quiet=machine)
|
||||
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
|
||||
return tags
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries"]
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))()
|
||||
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
|
||||
elif name in __all__: return importlib.import_module("." + name, __name__)
|
||||
else: raise AttributeError(f"{name} does not exists under {__name__}")
|
||||
if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))()
|
||||
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
|
||||
elif name in __all__: return importlib.import_module("." + name, __name__)
|
||||
else: raise AttributeError(f"{name} does not exists under {__name__}")
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""OpenLLM CLI.
|
||||
|
||||
For more information see ``openllm -h``.
|
||||
|
||||
@@ -47,71 +47,64 @@ from ..utils import is_peft_available
|
||||
from ..utils import resolve_user_filepath
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import subprocess
|
||||
import subprocess
|
||||
|
||||
from .._configuration import LLMConfig
|
||||
from .._types import DictStrAny
|
||||
from .._types import P
|
||||
TupleStr = tuple[str, ...]
|
||||
else: TupleStr = tuple
|
||||
from .._configuration import LLMConfig
|
||||
from .._types import DictStrAny
|
||||
from .._types import P
|
||||
TupleStr = tuple[str, ...]
|
||||
else:
|
||||
TupleStr = tuple
|
||||
|
||||
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
|
||||
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: tuple[str, ...] | None,
|
||||
environ: DictStrAny,
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
|
||||
if device:
|
||||
if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts)
|
||||
environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env
|
||||
return environ
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
|
||||
if device:
|
||||
if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts)
|
||||
environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env
|
||||
return environ
|
||||
|
||||
_adapter_mapping_key = "adapter_map"
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try: adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError: pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
return None
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try:
|
||||
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
return None
|
||||
|
||||
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
|
||||
"""Generate a 'click.Command' for any given LLM.
|
||||
"""Generate a 'click.Command' for any given LLM.
|
||||
|
||||
Args:
|
||||
group: the target ``click.Group`` to save this LLM cli under
|
||||
model: The name of the model or the ``bentoml.Bento`` instance.
|
||||
Args:
|
||||
group: the target ``click.Group`` to save this LLM cli under
|
||||
model: The name of the model or the ``bentoml.Bento`` instance.
|
||||
|
||||
Returns:
|
||||
The click.Command for starting the model server
|
||||
Returns:
|
||||
The click.Command for starting the model server
|
||||
|
||||
Note that the internal commands will return the llm_config and a boolean determine
|
||||
whether the server is run with GPU or not.
|
||||
"""
|
||||
llm_config = AutoConfig.for_model(model)
|
||||
Note that the internal commands will return the llm_config and a boolean determine
|
||||
whether the server is run with GPU or not.
|
||||
"""
|
||||
llm_config = AutoConfig.for_model(model)
|
||||
|
||||
command_attrs: DictStrAny = dict(
|
||||
name=llm_config["model_name"],
|
||||
context_settings=_context_settings or termui.CONTEXT_SETTINGS,
|
||||
short_help=f"Start a LLMServer for '{model}'",
|
||||
aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None,
|
||||
help=f"""\
|
||||
command_attrs: DictStrAny = dict(
|
||||
name=llm_config["model_name"], context_settings=_context_settings or termui.CONTEXT_SETTINGS, short_help=f"Start a LLMServer for '{model}'", aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None, help=f"""\
|
||||
{llm_config['env'].start_docstring}
|
||||
|
||||
\b
|
||||
@@ -130,142 +123,125 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
\b
|
||||
{orjson.dumps(llm_config['model_ids'], option=orjson.OPT_INDENT_2).decode()}
|
||||
""",
|
||||
)
|
||||
)
|
||||
|
||||
if llm_config["requires_gpu"] and device_count() < 1:
|
||||
# NOTE: The model requires GPU, therefore we will return a dummy command
|
||||
command_attrs.update({"short_help": "(Disabled because there is no GPU available)", "help": f"""{model} is currently not available to run on your local machine because it requires GPU for inference."""})
|
||||
return noop_command(group, llm_config, _serve_grpc, **command_attrs)
|
||||
if llm_config["requires_gpu"] and device_count() < 1:
|
||||
# NOTE: The model requires GPU, therefore we will return a dummy command
|
||||
command_attrs.update({"short_help": "(Disabled because there is no GPU available)", "help": f"""{model} is currently not available to run on your local machine because it requires GPU for inference."""})
|
||||
return noop_command(group, llm_config, _serve_grpc, **command_attrs)
|
||||
|
||||
@group.command(**command_attrs)
|
||||
@start_decorator(llm_config, serve_grpc=_serve_grpc)
|
||||
@click.pass_context
|
||||
def start_cmd(
|
||||
ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool,
|
||||
serialisation_format: t.Literal["safetensors", "legacy"], adapter_id: str | None, return_process: bool, **attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
fast = str(fast).upper() in ENV_VARS_TRUE_VALUES
|
||||
if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES:
|
||||
termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow")
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
server_timeout = first_not_none(server_timeout, default=config["timeout"])
|
||||
server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout})
|
||||
if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1"
|
||||
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop("development")
|
||||
server_attrs.setdefault("production", not development)
|
||||
wpr = first_not_none(workers_per_resource, default=config["workers_per_resource"])
|
||||
|
||||
@group.command(**command_attrs)
|
||||
@start_decorator(llm_config, serve_grpc=_serve_grpc)
|
||||
@click.pass_context
|
||||
def start_cmd(
|
||||
ctx: click.Context, /,
|
||||
server_timeout: int,
|
||||
model_id: str | None,
|
||||
model_version: str | None,
|
||||
workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString,
|
||||
device: tuple[str, ...],
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None,
|
||||
bettertransformer: bool | None,
|
||||
runtime: t.Literal["ggml", "transformers"],
|
||||
fast: bool,
|
||||
serialisation_format: t.Literal["safetensors", "legacy"],
|
||||
adapter_id: str | None,
|
||||
return_process: bool,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
fast = str(fast).upper() in ENV_VARS_TRUE_VALUES
|
||||
if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES:
|
||||
termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow")
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
server_timeout = first_not_none(server_timeout, default=config["timeout"])
|
||||
server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout})
|
||||
if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1"
|
||||
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop("development")
|
||||
server_attrs.setdefault("production", not development)
|
||||
wpr = first_not_none(workers_per_resource, default=config["workers_per_resource"])
|
||||
|
||||
if isinstance(wpr, str):
|
||||
if wpr == "round_robin": wpr = 1.0
|
||||
elif wpr == "conserved":
|
||||
if device and device_count() == 0:
|
||||
termui.echo("--device will have no effect as there is no GPUs available", fg="yellow")
|
||||
wpr = 1.0
|
||||
else:
|
||||
available_gpu = len(device) if device else device_count()
|
||||
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
|
||||
else: wpr = float(wpr)
|
||||
elif isinstance(wpr, int): wpr = float(wpr)
|
||||
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
env = EnvVarMixin(config["model_name"], config.default_implementation(), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr))
|
||||
|
||||
# NOTE: This is to set current configuration
|
||||
start_env = os.environ.copy()
|
||||
start_env = parse_config_options(config, server_timeout, wpr, device, start_env)
|
||||
if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow")
|
||||
|
||||
start_env.update(
|
||||
{
|
||||
"OPENLLM_MODEL": model,
|
||||
"BENTOML_DEBUG": str(get_debug_mode()),
|
||||
"BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()),
|
||||
"OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(),
|
||||
"OPENLLM_SERIALIZATION": serialisation_format,
|
||||
env.runtime: env.runtime_value,
|
||||
env.framework: env.framework_value,
|
||||
}
|
||||
)
|
||||
if env.model_id_value: start_env[env.model_id] = str(env.model_id_value)
|
||||
# NOTE: quantize and bettertransformer value is already assigned within env
|
||||
if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value)
|
||||
if quantize is not None: start_env[env.quantize] = str(env.quantize_value)
|
||||
|
||||
llm = infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format)
|
||||
start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id})
|
||||
|
||||
server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs)
|
||||
analytics.track_start_init(llm.config)
|
||||
|
||||
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
|
||||
cmd_name = f"openllm build {model_name}"
|
||||
if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]])
|
||||
if not get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue")
|
||||
|
||||
if return_process:
|
||||
server.start(env=start_env, text=True)
|
||||
if server.process is None: raise click.ClickException("Failed to start the server.")
|
||||
return server.process
|
||||
if isinstance(wpr, str):
|
||||
if wpr == "round_robin": wpr = 1.0
|
||||
elif wpr == "conserved":
|
||||
if device and device_count() == 0:
|
||||
termui.echo("--device will have no effect as there is no GPUs available", fg="yellow")
|
||||
wpr = 1.0
|
||||
else:
|
||||
try: server.start(env=start_env, text=True, blocking=True)
|
||||
except KeyboardInterrupt: next_step(model, adapter_map)
|
||||
except Exception as err: termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red")
|
||||
else: next_step(model, adapter_map)
|
||||
available_gpu = len(device) if device else device_count()
|
||||
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
|
||||
else:
|
||||
wpr = float(wpr)
|
||||
elif isinstance(wpr, int):
|
||||
wpr = float(wpr)
|
||||
|
||||
# NOTE: Return the configuration for telemetry purposes.
|
||||
return config
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
env = EnvVarMixin(config["model_name"], config.default_implementation(), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr))
|
||||
|
||||
return start_cmd
|
||||
# NOTE: This is to set current configuration
|
||||
start_env = os.environ.copy()
|
||||
start_env = parse_config_options(config, server_timeout, wpr, device, start_env)
|
||||
if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow")
|
||||
|
||||
start_env.update({"OPENLLM_MODEL": model, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), "OPENLLM_SERIALIZATION": serialisation_format, env.runtime: env.runtime_value, env.framework: env.framework_value,})
|
||||
if env.model_id_value: start_env[env.model_id] = str(env.model_id_value)
|
||||
# NOTE: quantize and bettertransformer value is already assigned within env
|
||||
if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value)
|
||||
if quantize is not None: start_env[env.quantize] = str(env.quantize_value)
|
||||
|
||||
llm = infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format)
|
||||
start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id})
|
||||
|
||||
server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs)
|
||||
analytics.track_start_init(llm.config)
|
||||
|
||||
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
|
||||
cmd_name = f"openllm build {model_name}"
|
||||
if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]])
|
||||
if not get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue")
|
||||
|
||||
if return_process:
|
||||
server.start(env=start_env, text=True)
|
||||
if server.process is None: raise click.ClickException("Failed to start the server.")
|
||||
return server.process
|
||||
else:
|
||||
try:
|
||||
server.start(env=start_env, text=True, blocking=True)
|
||||
except KeyboardInterrupt:
|
||||
next_step(model, adapter_map)
|
||||
except Exception as err:
|
||||
termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red")
|
||||
else:
|
||||
next_step(model, adapter_map)
|
||||
|
||||
# NOTE: Return the configuration for telemetry purposes.
|
||||
return config
|
||||
|
||||
return start_cmd
|
||||
|
||||
def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, **command_attrs: t.Any) -> click.Command:
|
||||
context_settings = command_attrs.pop("context_settings", {})
|
||||
context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True})
|
||||
command_attrs["context_settings"] = context_settings
|
||||
# NOTE: The model requires GPU, therefore we will return a dummy command
|
||||
@group.command(**command_attrs)
|
||||
def noop(**_: t.Any) -> LLMConfig:
|
||||
termui.echo("No GPU available, therefore this command is disabled", fg="red")
|
||||
analytics.track_start_init(llm_config)
|
||||
return llm_config
|
||||
return noop
|
||||
context_settings = command_attrs.pop("context_settings", {})
|
||||
context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True})
|
||||
command_attrs["context_settings"] = context_settings
|
||||
# NOTE: The model requires GPU, therefore we will return a dummy command
|
||||
@group.command(**command_attrs)
|
||||
def noop(**_: t.Any) -> LLMConfig:
|
||||
termui.echo("No GPU available, therefore this command is disabled", fg="red")
|
||||
analytics.track_start_init(llm_config)
|
||||
return llm_config
|
||||
|
||||
return noop
|
||||
|
||||
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
|
||||
if adapter_map and not is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config.env['framework']}=\"pt\"' to run with quantization.")
|
||||
requirements = llm_config["requirements"]
|
||||
if requirements is not None and len(requirements) > 0:
|
||||
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
|
||||
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
|
||||
if adapter_map and not is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config.env['framework']}=\"pt\"' to run with quantization.")
|
||||
requirements = llm_config["requirements"]
|
||||
if requirements is not None and len(requirements) > 0:
|
||||
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
|
||||
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
|
||||
|
||||
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
return lambda fn: compose(*[
|
||||
llm_config.to_click_options,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
fast_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
"LLM Optimization Options",
|
||||
help="""Optimization related options.
|
||||
return lambda fn: compose(
|
||||
*[
|
||||
llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
fast_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
"LLM Optimization Options", help="""Optimization related options.
|
||||
|
||||
OpenLLM supports running model with [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/),
|
||||
k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
@@ -275,15 +251,14 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help=f"Assign GPU devices (if available) for {llm_config['model_name']}.", show_envvar=True),
|
||||
cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."),
|
||||
quantize_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
"Fine-tuning related options",
|
||||
help="""\
|
||||
),
|
||||
cog.optgroup.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help=f"Assign GPU devices (if available) for {llm_config['model_name']}.", show_envvar=True),
|
||||
cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."),
|
||||
quantize_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
"Fine-tuning related options", help="""\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
@@ -298,18 +273,19 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
|
||||
```
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"),
|
||||
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
|
||||
])(fn)
|
||||
),
|
||||
cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"),
|
||||
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
|
||||
]
|
||||
)(fn)
|
||||
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> TupleStr | None:
|
||||
if value is None: return value
|
||||
if not LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
|
||||
el: TupleStr = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == "all": return tuple(map(str, available_devices()))
|
||||
return el
|
||||
if value is None: return value
|
||||
if not LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
|
||||
el: TupleStr = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == "all": return tuple(map(str, available_devices()))
|
||||
return el
|
||||
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
@@ -317,67 +293,78 @@ def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tup
|
||||
_IGNORED_OPTIONS = {"working_dir", "production", "protocol_version"}
|
||||
|
||||
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
|
||||
from bentoml_cli.cli import cli
|
||||
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
command = "serve" if not serve_grpc else "serve-grpc"
|
||||
group = cog.optgroup.group(
|
||||
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
|
||||
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
command = "serve" if not serve_grpc else "serve-grpc"
|
||||
group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",)
|
||||
|
||||
def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands[command]
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [p for p in serve_command.params[1 :-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
attrs.pop("param_type_name")
|
||||
# name is not a valid args
|
||||
attrs.pop("name")
|
||||
# type can be determine from default value
|
||||
attrs.pop("type")
|
||||
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
|
||||
f = cog.optgroup.option(*param_decls, **attrs)(f)
|
||||
return group(f)
|
||||
return decorator
|
||||
def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands[command]
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
attrs.pop("param_type_name")
|
||||
# name is not a valid args
|
||||
attrs.pop("name")
|
||||
# type can be determine from default value
|
||||
attrs.pop("type")
|
||||
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
|
||||
f = cog.optgroup.option(*param_decls, **attrs)(f)
|
||||
return group(f)
|
||||
|
||||
return decorator
|
||||
|
||||
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
|
||||
|
||||
def cli_option(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click.option`` with some sauce.
|
||||
"""General ``@click.option`` with some sauce.
|
||||
|
||||
This decorator extends the default ``@click.option`` plus a factory option to use which type of option, for example: [click, click_option_group.optgroup]
|
||||
"""
|
||||
attrs.setdefault("help", "General option for OpenLLM CLI.")
|
||||
factory = attrs.pop("factory", click)
|
||||
def decorator(f: FC | None) -> FC: return t.cast(FC, factory.option(*param_decls, **attrs)(f) if f is not None else factory.option(*param_decls, **attrs))
|
||||
return decorator
|
||||
This decorator extends the default ``@click.option`` plus a factory option to use which type of option, for example: [click, click_option_group.optgroup]
|
||||
"""
|
||||
attrs.setdefault("help", "General option for OpenLLM CLI.")
|
||||
factory = attrs.pop("factory", click)
|
||||
|
||||
def decorator(f: FC | None) -> FC:
|
||||
return t.cast(FC, factory.option(*param_decls, **attrs)(f) if f is not None else factory.option(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
|
||||
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
output = ["json", "pretty", "porcelain"]
|
||||
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: return [CompletionItem(it) for it in output]
|
||||
return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True,
|
||||
envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f)
|
||||
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True,
|
||||
help="""Whether to skip checking if models is already in store.
|
||||
output = ["json", "pretty", "porcelain"]
|
||||
|
||||
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
|
||||
return [CompletionItem(it) for it in output]
|
||||
|
||||
return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f)
|
||||
|
||||
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, help="""Whether to skip checking if models is already in store.
|
||||
|
||||
This is useful if you already downloaded or setup the model beforehand.
|
||||
""", **attrs)(f)
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None,
|
||||
help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
|
||||
""", **attrs
|
||||
)(f)
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]:
|
||||
arg = click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in CONFIG_MAPPING]), required=required)
|
||||
return arg(f) if f is not None else arg
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None,
|
||||
envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
arg = click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in CONFIG_MAPPING]), required=required)
|
||||
return arg(f) if f is not None else arg
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None, envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None, help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
@@ -388,11 +375,16 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
**Note** that the model can also be served with quantized weights.
|
||||
""" + ("""
|
||||
**Note** that this will set the mode for serving within deployment.""" if build else "") + """
|
||||
**Note** that quantization are currently only available in *PyTorch* models.""", **attrs)(f)
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--workers-per-resource", default=None, callback=workers_per_resource_callback, type=str, required=False,
|
||||
help="""Number of workers per resource assigned.
|
||||
""" + (
|
||||
"""
|
||||
**Note** that this will set the mode for serving within deployment.""" if build else ""
|
||||
) + """
|
||||
**Note** that quantization are currently only available in *PyTorch* models.""", **attrs
|
||||
)(f)
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--workers-per-resource", default=None, callback=workers_per_resource_callback, type=str, required=False, help="""Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -402,16 +394,22 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
""" + ("""\n
|
||||
""" + (
|
||||
"""\n
|
||||
**Note**: The workers value passed into 'build' will determine how the LLM can
|
||||
be provisioned in Kubernetes as well as in standalone container. This will
|
||||
ensure it has the same effect with 'openllm start --workers ...'""" if build else ""), **attrs)(f)
|
||||
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--bettertransformer", is_flag=True, default=None, envvar=model_env.bettertransformer if model_env is not None else None, show_envvar=model_env is not None,
|
||||
help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.",
|
||||
**attrs)(f)
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--serialisation", "--serialization", "serialisation_format", type=click.Choice(["safetensors", "legacy"]),
|
||||
default="safetensors", show_default=True, show_envvar=True, envvar="OPENLLM_SERIALIZATION",
|
||||
help="""Serialisation format for save/load LLM.
|
||||
ensure it has the same effect with 'openllm start --workers ...'""" if build else ""
|
||||
), **attrs
|
||||
)(f)
|
||||
|
||||
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--bettertransformer", is_flag=True, default=None, envvar=model_env.bettertransformer if model_env is not None else None, show_envvar=model_env is not None, help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.", **attrs
|
||||
)(f)
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--serialisation", "--serialization", "serialisation_format", type=click.Choice(["safetensors", "legacy"]), default="safetensors", show_default=True, show_envvar=True, envvar="OPENLLM_SERIALIZATION", help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -429,29 +427,34 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
|
||||
**Note** that GGML format is working in progress.
|
||||
""", **attrs
|
||||
)(f)
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY",
|
||||
callback=container_registry_callback,
|
||||
help="""The default container registry to get the base image for building BentoLLM.
|
||||
)(f)
|
||||
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY", callback=container_registry_callback, help="""The default container registry to get the base image for building BentoLLM.
|
||||
|
||||
Currently, it supports 'ecr', 'ghcr.io', 'docker.io'
|
||||
|
||||
\b
|
||||
**Note** that in order to build the base image, you will need a GPUs to compile custom kernel. See ``openllm ext build-base-container`` for more information.
|
||||
""")(f)
|
||||
"""
|
||||
)(f)
|
||||
|
||||
_wpr_strategies = {"round_robin", "conserved"}
|
||||
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
value = inflection.underscore(value)
|
||||
if value in _wpr_strategies: return value
|
||||
if value is None: return value
|
||||
value = inflection.underscore(value)
|
||||
if value in _wpr_strategies: return value
|
||||
else:
|
||||
try:
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
else:
|
||||
try: float(value) # type: ignore[arg-type]
|
||||
except ValueError: raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
else: return value
|
||||
return value
|
||||
|
||||
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value not in bundle.supported_registries: raise click.BadParameter(f"Value must be one of {bundle.supported_registries}", ctx, param)
|
||||
return value
|
||||
if value is None: return value
|
||||
if value not in bundle.supported_registries: raise click.BadParameter(f"Value must be one of {bundle.supported_registries}", ctx, param)
|
||||
return value
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""OpenLLM CLI Extension.
|
||||
|
||||
The following directory contains all possible extensions for OpenLLM CLI
|
||||
|
||||
@@ -24,12 +24,11 @@ from .. import termui
|
||||
from .._factory import machine_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm.bundle.oci import LiteralContainerRegistry
|
||||
from openllm.bundle.oci import LiteralContainerVersionStrategy
|
||||
from openllm.bundle.oci import LiteralContainerRegistry
|
||||
from openllm.bundle.oci import LiteralContainerVersionStrategy
|
||||
|
||||
@click.command("build_base_container",
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help="""Base image builder for BentoLLM.
|
||||
@click.command(
|
||||
"build_base_container", context_settings=termui.CONTEXT_SETTINGS, help="""Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
|
||||
@@ -40,12 +39,13 @@ if t.TYPE_CHECKING:
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
@click.option("--registry", multiple=True, type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)), help="Target registry to create image tag on.", default=None)
|
||||
@click.option("--version-strategy", type=click.Choice(["release", "latest", "nightly"]), default="nightly", help="Version strategy to use for tagging the image.")
|
||||
@click.option("--push/--no-push", help="Whether to push to remote repository", is_flag=True, default=False)
|
||||
@machine_option
|
||||
def cli(registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
return mapping
|
||||
mapping = openllm.bundle.build_container(registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
return mapping
|
||||
|
||||
@@ -28,7 +28,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from .. import termui
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command("dive_bentos", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument("bento", type=str)
|
||||
@@ -36,12 +36,15 @@ if t.TYPE_CHECKING:
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try: bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
|
||||
if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle": ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
|
||||
if machine: return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path)
|
||||
else:subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
|
||||
if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle":
|
||||
ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
|
||||
if machine: return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path)
|
||||
else: subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -29,28 +29,28 @@ from .. import termui
|
||||
from ...utils import bentoml_cattr
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command("get_containerfile", context_settings=termui.CONTEXT_SETTINGS, help="Return Containerfile of any given Bento.")
|
||||
@click.argument("bento", type=str)
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
|
||||
try: bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
|
||||
# The logic below are similar to bentoml._internal.container.construct_containerfile
|
||||
with open(bentomodel.path_of("bento.yaml"), "r") as f:
|
||||
options = BentoInfo.from_yaml_file(f)
|
||||
# NOTE: dockerfile_template is already included in the
|
||||
# Dockerfile inside bento, and it is not relevant to
|
||||
# construct_containerfile. Hence it is safe to set it to None here.
|
||||
# See https://github.com/bentoml/BentoML/issues/3399.
|
||||
docker_attrs = bentoml_cattr.unstructure(options.docker)
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template"
|
||||
termui.echo(generate_containerfile(
|
||||
docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True,
|
||||
), fg="white")
|
||||
return bentomodel.path
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
|
||||
# The logic below are similar to bentoml._internal.container.construct_containerfile
|
||||
with open(bentomodel.path_of("bento.yaml"), "r") as f:
|
||||
options = BentoInfo.from_yaml_file(f)
|
||||
# NOTE: dockerfile_template is already included in the
|
||||
# Dockerfile inside bento, and it is not relevant to
|
||||
# construct_containerfile. Hence it is safe to set it to None here.
|
||||
# See https://github.com/bentoml/BentoML/issues/3399.
|
||||
docker_attrs = bentoml_cattr.unstructure(options.docker)
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template"
|
||||
termui.echo(generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True,), fg="white")
|
||||
return bentomodel.path
|
||||
|
||||
@@ -26,7 +26,7 @@ from .. import termui
|
||||
from ..._prompt import process_prompt
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..entrypoint import LiteralOutput
|
||||
from ..entrypoint import LiteralOutput
|
||||
|
||||
@click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
|
||||
@@ -37,27 +37,29 @@ if t.TYPE_CHECKING:
|
||||
@click.option("--opt", help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]")
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = openllm.utils.EnvVarMixin(model_name).module
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None)
|
||||
prompt_mapping = getattr(module, "PROMPT_MAPPING", None)
|
||||
if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.")
|
||||
raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None
|
||||
if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})")
|
||||
_prompt_template = template(format)
|
||||
else: _prompt_template = template
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
if machine: return repr(fully_formatted)
|
||||
elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white")
|
||||
elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
else:
|
||||
termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta")
|
||||
termui.echo(fully_formatted, fg="white")
|
||||
except AttributeError: raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None
|
||||
ctx.exit(0)
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = openllm.utils.EnvVarMixin(model_name).module
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None)
|
||||
prompt_mapping = getattr(module, "PROMPT_MAPPING", None)
|
||||
if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.")
|
||||
raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None
|
||||
if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})")
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
if machine: return repr(fully_formatted)
|
||||
elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white")
|
||||
elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
else:
|
||||
termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta")
|
||||
termui.echo(fully_formatted, fg="white")
|
||||
except AttributeError:
|
||||
raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -26,9 +26,9 @@ from .. import termui
|
||||
@click.command("list_bentos", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
_local_bentos = {str(i.tag): i.info.labels["start_name"] for i in bentoml.list() if "start_name" in i.info.labels}
|
||||
mapping = {k: [tag for tag, name in _local_bentos.items() if name == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
ctx.exit(0)
|
||||
"""List available bentos built by OpenLLM."""
|
||||
_local_bentos = {str(i.tag): i.info.labels["start_name"] for i in bentoml.list() if "start_name" in i.info.labels}
|
||||
mapping = {k: [tag for tag, name in _local_bentos.items() if name == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -26,16 +26,16 @@ from .. import termui
|
||||
from .._factory import model_name_argument
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..._types import DictStrAny
|
||||
from ..._types import DictStrAny
|
||||
|
||||
@click.command("list_models", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in models}
|
||||
if model_name is not None: ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k,v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
return local_models
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k] for k in models}
|
||||
if model_name is not None: ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
return local_models
|
||||
|
||||
@@ -23,13 +23,11 @@ from ..utils import get_debug_mode
|
||||
from ..utils import get_quiet_mode
|
||||
|
||||
def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None:
|
||||
attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho
|
||||
if not get_quiet_mode(): call(text, **attrs)
|
||||
|
||||
attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho
|
||||
if not get_quiet_mode(): call(text, **attrs)
|
||||
|
||||
COLUMNS = int(os.getenv("COLUMNS", str(120)))
|
||||
|
||||
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore}
|
||||
|
||||
|
||||
__all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"]
|
||||
|
||||
@@ -24,23 +24,25 @@ import importlib
|
||||
import itertools
|
||||
import typing as t
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"runtimes.grpc": ["AsyncGrpcClient", "GrpcClient"],
|
||||
"runtimes.http": ["AsyncHTTPClient", "HTTPClient"],
|
||||
}
|
||||
_import_structure: dict[str, list[str]] = {"runtimes.grpc": ["AsyncGrpcClient", "GrpcClient"], "runtimes.http": ["AsyncHTTPClient", "HTTPClient"],}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_client import AsyncGrpcClient as AsyncGrpcClient
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
|
||||
from openllm_client import GrpcClient as GrpcClient
|
||||
from openllm_client import HTTPClient as HTTPClient
|
||||
from openllm_client import AsyncGrpcClient as AsyncGrpcClient
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
|
||||
from openllm_client import GrpcClient as GrpcClient
|
||||
from openllm_client import HTTPClient as HTTPClient
|
||||
|
||||
_module = "openllm_client"
|
||||
|
||||
__all__ = list(itertools.chain.from_iterable(_import_structure.values()))
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name in _import_structure: return importlib.import_module(f".{name}", _module)
|
||||
try: module = next(module for module, attrs in _import_structure.items() if name in attrs)
|
||||
except StopIteration: raise AttributeError(f"module {_module} has no attribute {name}") from None
|
||||
return getattr(importlib.import_module(f".{module}", _module), name)
|
||||
if name in _import_structure: return importlib.import_module(f".{name}", _module)
|
||||
try:
|
||||
module = next(module for module, attrs in _import_structure.items() if name in attrs)
|
||||
except StopIteration:
|
||||
raise AttributeError(f"module {_module} has no attribute {name}") from None
|
||||
return getattr(importlib.import_module(f".{module}", _module), name)
|
||||
|
||||
@@ -17,30 +17,25 @@ from __future__ import annotations
|
||||
import bentoml
|
||||
|
||||
class OpenLLMException(bentoml.exceptions.BentoMLException):
|
||||
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
|
||||
|
||||
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
|
||||
|
||||
class GpuNotAvailableError(OpenLLMException):
|
||||
"""Raised when there is no GPU available in given system."""
|
||||
|
||||
"""Raised when there is no GPU available in given system."""
|
||||
|
||||
class ValidationError(OpenLLMException):
|
||||
"""Raised when a validation fails."""
|
||||
|
||||
"""Raised when a validation fails."""
|
||||
|
||||
class ForbiddenAttributeError(OpenLLMException):
|
||||
"""Raised when using an _internal field."""
|
||||
|
||||
"""Raised when using an _internal field."""
|
||||
|
||||
class MissingAnnotationAttributeError(OpenLLMException):
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
"""Raised when a dependency is missing."""
|
||||
|
||||
class Error(BaseException):
|
||||
"""To be used instead of naked raise."""
|
||||
"""To be used instead of naked raise."""
|
||||
|
||||
class FineTuneStrategyNotSupportedError(OpenLLMException):
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
|
||||
@@ -23,20 +23,21 @@ _MODELS: set[str] = {'auto', 'baichuan', 'chatglm', 'dolly_v2', 'falcon', 'flan_
|
||||
# fmt: on
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# fmt: off
|
||||
# update-models-import.py: start types
|
||||
from . import auto as auto
|
||||
from . import baichuan as baichuan
|
||||
from . import chatglm as chatglm
|
||||
from . import dolly_v2 as dolly_v2
|
||||
from . import falcon as falcon
|
||||
from . import flan_t5 as flan_t5
|
||||
from . import gpt_neox as gpt_neox
|
||||
from . import llama as llama
|
||||
from . import mpt as mpt
|
||||
from . import opt as opt
|
||||
from . import stablelm as stablelm
|
||||
from . import starcoder as starcoder
|
||||
# update-models-import.py: stop types
|
||||
# fmt: on
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}, module_spec=__spec__)
|
||||
# fmt: off
|
||||
# update-models-import.py: start types
|
||||
from . import auto as auto
|
||||
from . import baichuan as baichuan
|
||||
from . import chatglm as chatglm
|
||||
from . import dolly_v2 as dolly_v2
|
||||
from . import falcon as falcon
|
||||
from . import flan_t5 as flan_t5
|
||||
from . import gpt_neox as gpt_neox
|
||||
from . import llama as llama
|
||||
from . import mpt as mpt
|
||||
from . import opt as opt
|
||||
from . import stablelm as stablelm
|
||||
from . import starcoder as starcoder
|
||||
# update-models-import.py: stop types
|
||||
# fmt: on
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}, module_spec=__spec__)
|
||||
|
||||
@@ -20,51 +20,63 @@ from ...utils import is_flax_available
|
||||
from ...utils import is_tf_available
|
||||
from ...utils import is_torch_available
|
||||
from ...utils import is_vllm_available
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"],
|
||||
"modeling_auto": ["MODEL_MAPPING_NAMES"],
|
||||
"modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"],
|
||||
"modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"],
|
||||
"modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"],
|
||||
}
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"], "modeling_auto": ["MODEL_MAPPING_NAMES"], "modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"], "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"],}
|
||||
try:
|
||||
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
|
||||
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
|
||||
try:
|
||||
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
|
||||
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
|
||||
try:
|
||||
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
|
||||
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
|
||||
try:
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
|
||||
from .configuration_auto import AutoConfig as AutoConfig
|
||||
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
|
||||
from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
try:
|
||||
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
|
||||
try:
|
||||
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
|
||||
try:
|
||||
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
try:
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError: pass
|
||||
else: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
|
||||
from .configuration_auto import AutoConfig as AutoConfig
|
||||
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
|
||||
from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
try:
|
||||
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
|
||||
try:
|
||||
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
|
||||
try:
|
||||
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
try:
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -18,77 +18,82 @@ import inflection
|
||||
import openllm
|
||||
from ...utils import ReprMixin
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
|
||||
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
|
||||
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
|
||||
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
|
||||
import types
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
|
||||
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
|
||||
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
|
||||
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
|
||||
else:
|
||||
ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any
|
||||
ConfigOrderedDict = OrderedDict
|
||||
ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any
|
||||
ConfigOrderedDict = OrderedDict
|
||||
# NOTE: This is the entrypoint when adding new model config
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("chatglm", "ChatGLMConfig"),
|
||||
("dolly_v2", "DollyV2Config"),
|
||||
("falcon", "FalconConfig"),
|
||||
("flan_t5", "FlanT5Config"),
|
||||
("gpt_neox", "GPTNeoXConfig"),
|
||||
("llama", "LlamaConfig"),
|
||||
("mpt", "MPTConfig"),
|
||||
("opt", "OPTConfig"),
|
||||
("stablelm", "StableLMConfig"),
|
||||
("starcoder", "StarCoderConfig"),
|
||||
("baichuan", "BaichuanConfig"),
|
||||
]
|
||||
)
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig"),])
|
||||
|
||||
class _LazyConfigMapping(ConfigOrderedDict, ReprMixin):
|
||||
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
|
||||
self._mapping = mapping
|
||||
self._extra_content: dict[str, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
if key not in self._mapping:
|
||||
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
|
||||
raise KeyError(key)
|
||||
value, module_name = self._mapping[key], inflection.underscore(key)
|
||||
if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
|
||||
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
|
||||
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
|
||||
return getattr(openllm, value)
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]: return set(self._mapping.keys())
|
||||
def __repr__(self) -> str: return ReprMixin.__repr__(self)
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: yield from self._mapping.items()
|
||||
def keys(self): return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
def values(self): return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
|
||||
def items(self): return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
|
||||
def __iter__(self): return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
def __contains__(self, item: t.Any): return item in self._mapping or item in self._extra_content
|
||||
def register(self, key: str, value: t.Any):
|
||||
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
|
||||
self._extra_content[key] = value
|
||||
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
|
||||
self._mapping = mapping
|
||||
self._extra_content: dict[str, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
if key not in self._mapping:
|
||||
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
|
||||
raise KeyError(key)
|
||||
value, module_name = self._mapping[key], inflection.underscore(key)
|
||||
if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
|
||||
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
|
||||
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
|
||||
return getattr(openllm, value)
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return set(self._mapping.keys())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ReprMixin.__repr__(self)
|
||||
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]:
|
||||
yield from self._mapping.items()
|
||||
|
||||
def keys(self):
|
||||
return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
|
||||
def values(self):
|
||||
return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
|
||||
|
||||
def items(self):
|
||||
return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
|
||||
def __contains__(self, item: t.Any):
|
||||
return item in self._mapping or item in self._extra_content
|
||||
|
||||
def register(self, key: str, value: t.Any):
|
||||
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
# The below handle special alias when we call underscore to the name directly
|
||||
# without processing camelcase first.
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {
|
||||
"chat_glm": "chatglm",
|
||||
"stable_lm": "stablelm",
|
||||
"star_coder": "starcoder",
|
||||
"gpt_neo_x": "gpt_neox",
|
||||
}
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder", "gpt_neo_x": "gpt_neox",}
|
||||
|
||||
class AutoConfig:
|
||||
def __init__(self, *_: t.Any, **__: t.Any): raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
|
||||
model_name = inflection.underscore(model_name)
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
|
||||
model_name = inflection.underscore(name)
|
||||
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
def __init__(self, *_: t.Any, **__: t.Any):
|
||||
raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
|
||||
model_name = inflection.underscore(model_name)
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
|
||||
model_name = inflection.underscore(name)
|
||||
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
|
||||
@@ -21,35 +21,40 @@ import inflection
|
||||
import openllm
|
||||
from ...utils import ReprMixin
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from ..._llm import LLMRunner
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
import types
|
||||
from ..._llm import LLMRunner
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
else:
|
||||
ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any
|
||||
ConfigModelOrderedDict = OrderedDict
|
||||
ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any
|
||||
ConfigModelOrderedDict = OrderedDict
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAutoLLMClass:
|
||||
_model_mapping: _LazyAutoMapping
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
|
||||
@classmethod
|
||||
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
|
||||
"""The lower level API for creating a LLM instance.
|
||||
_model_mapping: _LazyAutoMapping
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
|
||||
"""The lower level API for creating a LLM instance.
|
||||
|
||||
```python
|
||||
>>> import openllm
|
||||
>>> llm = openllm.AutoLLM.for_model("flan-t5")
|
||||
```
|
||||
"""
|
||||
llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
|
||||
if ensure_available: llm.ensure_model_id_exists()
|
||||
return llm
|
||||
@classmethod
|
||||
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
"""Create a LLM Runner for the given model name.
|
||||
llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
|
||||
if ensure_available: llm.ensure_model_id_exists()
|
||||
return llm
|
||||
|
||||
@classmethod
|
||||
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
"""Create a LLM Runner for the given model name.
|
||||
|
||||
Args:
|
||||
model: The model name to instantiate.
|
||||
@@ -59,76 +64,107 @@ class BaseAutoLLMClass:
|
||||
Returns:
|
||||
A LLM instance.
|
||||
"""
|
||||
runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters)
|
||||
runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name}
|
||||
for k in runner_attrs: del attrs[k]
|
||||
return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs)
|
||||
@classmethod
|
||||
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
|
||||
"""Register a new model for this class.
|
||||
runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters)
|
||||
runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name}
|
||||
for k in runner_attrs:
|
||||
del attrs[k]
|
||||
return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs)
|
||||
|
||||
@classmethod
|
||||
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
|
||||
"""Register a new model for this class.
|
||||
|
||||
Args:
|
||||
config_class: The configuration corresponding to the model to register.
|
||||
llm_class: The runnable to register.
|
||||
"""
|
||||
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!")
|
||||
cls._model_mapping.register(config_class, llm_class)
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
config_class = openllm.AutoConfig.infer_class_from_name(name)
|
||||
if config_class in cls._model_mapping: return cls._model_mapping[config_class]
|
||||
raise ValueError(f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])}).")
|
||||
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class:
|
||||
raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!")
|
||||
cls._model_mapping.register(config_class, llm_class)
|
||||
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
config_class = openllm.AutoConfig.infer_class_from_name(name)
|
||||
if config_class in cls._model_mapping: return cls._model_mapping[config_class]
|
||||
raise ValueError(f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])}).")
|
||||
|
||||
def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
if attr is None: return
|
||||
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
|
||||
if hasattr(module, attr): return getattr(module, attr)
|
||||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level.
|
||||
openllm_module = importlib.import_module("openllm")
|
||||
if module != openllm_module:
|
||||
try: return getattribute_from_module(openllm_module, attr)
|
||||
except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
|
||||
raise ValueError(f"Could not find {attr} in {openllm_module}!")
|
||||
if attr is None: return
|
||||
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
|
||||
if hasattr(module, attr): return getattr(module, attr)
|
||||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level.
|
||||
openllm_module = importlib.import_module("openllm")
|
||||
if module != openllm_module:
|
||||
try:
|
||||
return getattribute_from_module(openllm_module, attr)
|
||||
except ValueError:
|
||||
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
|
||||
raise ValueError(f"Could not find {attr} in {openllm_module}!")
|
||||
|
||||
class _LazyAutoMapping(ConfigModelOrderedDict, ReprMixin):
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
|
||||
|
||||
This OrderedDict values() and keys() returns the list instead, so you don't
|
||||
have to do list(mapping.values()) to get the list of values.
|
||||
"""
|
||||
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
|
||||
self._config_mapping = config_mapping
|
||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||
self._model_mapping = model_mapping
|
||||
self._extra_content: dict[t.Any, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
def __len__(self): return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
|
||||
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
model_type = self._reverse_config_mapping[key.__name__]
|
||||
if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
|
||||
# Maybe there was several model types associated with this config.
|
||||
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
|
||||
for mtype in model_types:
|
||||
if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype])
|
||||
raise KeyError(key)
|
||||
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
|
||||
module_name = inflection.underscore(model_type)
|
||||
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
|
||||
return getattribute_from_module(self._modules[module_name], attr)
|
||||
def keys(self): return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys()))
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]: return set(self._config_mapping.keys())
|
||||
def __repr__(self) -> str: return ReprMixin.__repr__(self)
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
|
||||
def __bool__(self): return bool(self.keys())
|
||||
def values(self): return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values()))
|
||||
def items(self): return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items()))
|
||||
def __iter__(self): return iter(self.keys())
|
||||
def __contains__(self, item: t.Any):
|
||||
if item in self._extra_content: return True
|
||||
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
|
||||
return self._reverse_config_mapping[item.__name__] in self._model_mapping
|
||||
def register(self, key: t.Any, value: t.Any):
|
||||
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
|
||||
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
|
||||
self._extra_content[key] = value
|
||||
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
|
||||
self._config_mapping = config_mapping
|
||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||
self._model_mapping = model_mapping
|
||||
self._extra_content: dict[t.Any, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
|
||||
|
||||
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
model_type = self._reverse_config_mapping[key.__name__]
|
||||
if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
|
||||
# Maybe there was several model types associated with this config.
|
||||
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
|
||||
for mtype in model_types:
|
||||
if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype])
|
||||
raise KeyError(key)
|
||||
|
||||
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
|
||||
module_name = inflection.underscore(model_type)
|
||||
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
|
||||
return getattribute_from_module(self._modules[module_name], attr)
|
||||
|
||||
def keys(self):
|
||||
return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys()))
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return set(self._config_mapping.keys())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ReprMixin.__repr__(self)
|
||||
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]:
|
||||
yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.keys())
|
||||
|
||||
def values(self):
|
||||
return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values()))
|
||||
|
||||
def items(self):
|
||||
return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items()))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.keys())
|
||||
|
||||
def __contains__(self, item: t.Any):
|
||||
if item in self._extra_content: return True
|
||||
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
|
||||
return self._reverse_config_mapping[item.__name__] in self._model_mapping
|
||||
|
||||
def register(self, key: t.Any, value: t.Any):
|
||||
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
|
||||
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]
|
||||
|
||||
@@ -17,21 +17,9 @@ from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("chatglm", "ChatGLM"),
|
||||
("dolly_v2", "DollyV2"),
|
||||
("falcon", "Falcon"),
|
||||
("flan_t5", "FlanT5"),
|
||||
("gpt_neox", "GPTNeoX"),
|
||||
("llama", "Llama"),
|
||||
("mpt", "MPT"),
|
||||
("opt", "OPT"),
|
||||
("stablelm", "StableLM"),
|
||||
("starcoder", "StarCoder"),
|
||||
("baichuan", "Baichuan"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan"),])
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
class AutoLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
_model_mapping = MODEL_MAPPING
|
||||
|
||||
@@ -16,12 +16,9 @@ from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("flan_t5", "FlaxFlanT5"),
|
||||
("opt", "FlaxOPT"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT"),])
|
||||
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
|
||||
class AutoFlaxLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_FLAX_MAPPING
|
||||
_model_mapping = MODEL_FLAX_MAPPING
|
||||
|
||||
@@ -16,12 +16,9 @@ from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("flan_t5", "TFFlanT5"),
|
||||
("opt", "TFOPT"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT"),])
|
||||
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
|
||||
class AutoTFLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_TF_MAPPING
|
||||
_model_mapping = MODEL_TF_MAPPING
|
||||
|
||||
@@ -16,12 +16,9 @@ from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("llama", "VLLMLlama"),
|
||||
("opt", "VLLMOPT")
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict([("llama", "VLLMLlama"), ("opt", "VLLMOPT")])
|
||||
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
|
||||
|
||||
class AutoVLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_VLLM_MAPPING
|
||||
_model_mapping = MODEL_VLLM_MAPPING
|
||||
|
||||
@@ -18,18 +18,24 @@ from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_cpm_kernels_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_baichuan": ["BaichuanConfig", "START_BAICHUAN_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_baichuan"] = ["Baichuan"]
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_baichuan"] = ["Baichuan"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING
|
||||
from .configuration_baichuan import BaichuanConfig as BaichuanConfig
|
||||
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING
|
||||
from .configuration_baichuan import BaichuanConfig as BaichuanConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_baichuan import Baichuan as Baichuan
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_baichuan import Baichuan as Baichuan
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class BaichuanConfig(openllm.LLMConfig):
|
||||
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
|
||||
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
|
||||
|
||||
Baichuan-7B is based on Transformer architecture,
|
||||
which contains 7 billion parameters and trained on approximately 1.2 trillion tokens.
|
||||
@@ -23,28 +24,16 @@ class BaichuanConfig(openllm.LLMConfig):
|
||||
and English benchmarks (C-Eval, MMLU, etc).
|
||||
Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"timeout": 3600000,
|
||||
"requires_gpu": True,
|
||||
"url": "https://github.com/baichuan-inc/Baichuan-7B",
|
||||
"requirements": ["cpm-kernels", "sentencepiece"],
|
||||
"architecture": "BaiChuanForCausalLM",
|
||||
"default_id": "baichuan-inc/baichuan-7b",
|
||||
"model_ids": [
|
||||
"baichuan-inc/baichuan-7b",
|
||||
"baichuan-inc/baichuan-13b-base",
|
||||
"baichuan-inc/baichuan-13b-chat",
|
||||
"fireballoon/baichuan-vicuna-chinese-7b",
|
||||
"fireballoon/baichuan-vicuna-7b",
|
||||
"hiyouga/baichuan-7b-sft",
|
||||
],
|
||||
}
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/baichuan-inc/Baichuan-7B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "BaiChuanForCausalLM", "default_id": "baichuan-inc/baichuan-7b",
|
||||
"model_ids": ["baichuan-inc/baichuan-7b", "baichuan-inc/baichuan-13b-base", "baichuan-inc/baichuan-13b-chat", "fireballoon/baichuan-vicuna-chinese-7b", "fireballoon/baichuan-vicuna-7b", "hiyouga/baichuan-7b-sft",],
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
START_BAICHUAN_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Baichuan model.
|
||||
|
||||
|
||||
@@ -18,12 +18,18 @@ from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
||||
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
||||
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
@@ -18,17 +18,23 @@ from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_cpm_kernels_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_chatglm": ["ChatGLMConfig", "START_CHATGLM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_chatglm"] = ["ChatGLM"]
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_chatglm"] = ["ChatGLM"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING
|
||||
from .configuration_chatglm import ChatGLMConfig as ChatGLMConfig
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_chatglm import ChatGLM as ChatGLM
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING
|
||||
from .configuration_chatglm import ChatGLMConfig as ChatGLMConfig
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_chatglm import ChatGLM as ChatGLM
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class ChatGLMConfig(openllm.LLMConfig):
|
||||
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
|
||||
With the quantization technique, users can deploy locally on consumer-grade graphics cards
|
||||
(only 6GB of GPU memory is required at the INT4 quantization level).
|
||||
@@ -27,34 +28,20 @@ class ChatGLMConfig(openllm.LLMConfig):
|
||||
|
||||
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"timeout": 3600000,
|
||||
"requires_gpu": True,
|
||||
"url": "https://github.com/THUDM/ChatGLM-6B",
|
||||
"requirements": ["cpm-kernels", "sentencepiece"],
|
||||
"architecture": "ChatGLMForConditionalGeneration",
|
||||
"default_id": "thudm/chatglm-6b",
|
||||
"model_ids": [
|
||||
"thudm/chatglm-6b",
|
||||
"thudm/chatglm-6b-int8",
|
||||
"thudm/chatglm-6b-int4",
|
||||
"thudm/chatglm2-6b",
|
||||
"thudm/chatglm2-6b-int4",
|
||||
],
|
||||
}
|
||||
retain_history: bool = openllm.LLMConfig.Field(
|
||||
False,
|
||||
description="""Whether to retain history given to the model.
|
||||
If set to True, then the model will retain given history.""",
|
||||
)
|
||||
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
num_beams: int = 1
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/THUDM/ChatGLM-6B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "ChatGLMForConditionalGeneration", "default_id": "thudm/chatglm-6b",
|
||||
"model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4", "thudm/chatglm2-6b", "thudm/chatglm2-6b-int4",],
|
||||
}
|
||||
retain_history: bool = openllm.LLMConfig.Field(False, description="""Whether to retain history given to the model.
|
||||
If set to True, then the model will retain given history.""",)
|
||||
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
num_beams: int = 1
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
START_CHATGLM_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for ChatGLM model.
|
||||
|
||||
|
||||
@@ -17,38 +17,45 @@ import openllm
|
||||
from ..._llm import LLMEmbeddings
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
|
||||
class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ""
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history): prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001
|
||||
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001
|
||||
else: prompt_text = prompt
|
||||
postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None}
|
||||
# NOTE: The rest of attrs should be kwargs for GenerationConfig
|
||||
generate_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs}
|
||||
return prompt_text, generate_kwargs, postprocess_generate_kwargs
|
||||
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any):
|
||||
generated, history = generation_result
|
||||
if self.config.retain_history:
|
||||
assert chat_history is not None, "'retain_history' is True while there is no history provided."
|
||||
chat_history.extend(history)
|
||||
return generated
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
|
||||
with torch.inference_mode():
|
||||
self.model.eval()
|
||||
# Only use half precision if the model is not yet quantized
|
||||
if self.config.use_half_precision: self.model.half()
|
||||
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, output_hidden_states=True)
|
||||
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)
|
||||
embeddings.append(data.tolist())
|
||||
num_tokens += len(input_ids[0])
|
||||
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ""
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history):
|
||||
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001
|
||||
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001
|
||||
else:
|
||||
prompt_text = prompt
|
||||
postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None}
|
||||
# NOTE: The rest of attrs should be kwargs for GenerationConfig
|
||||
generate_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs}
|
||||
return prompt_text, generate_kwargs, postprocess_generate_kwargs
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any):
|
||||
generated, history = generation_result
|
||||
if self.config.retain_history:
|
||||
assert chat_history is not None, "'retain_history' is True while there is no history provided."
|
||||
chat_history.extend(history)
|
||||
return generated
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
|
||||
with torch.inference_mode():
|
||||
self.model.eval()
|
||||
# Only use half precision if the model is not yet quantized
|
||||
if self.config.use_half_precision: self.model.half()
|
||||
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, output_hidden_states=True)
|
||||
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)
|
||||
embeddings.append(data.tolist())
|
||||
num_tokens += len(input_ids[0])
|
||||
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
|
||||
|
||||
@@ -17,17 +17,23 @@ import typing as t
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_dolly_v2"] = ["DollyV2"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_dolly_v2"] = ["DollyV2"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING
|
||||
from .configuration_dolly_v2 import DollyV2Config as DollyV2Config
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_dolly_v2 import DollyV2 as DollyV2
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING
|
||||
from .configuration_dolly_v2 import DollyV2Config as DollyV2Config
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_dolly_v2 import DollyV2 as DollyV2
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -15,8 +15,9 @@ from __future__ import annotations
|
||||
import typing as t
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class DollyV2Config(openllm.LLMConfig):
|
||||
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
|
||||
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
|
||||
|
||||
Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k
|
||||
generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming,
|
||||
@@ -27,22 +28,16 @@ class DollyV2Config(openllm.LLMConfig):
|
||||
|
||||
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"timeout": 3600000,
|
||||
"url": "https://github.com/databrickslabs/dolly",
|
||||
"architecture": "GPTNeoXForCausalLM",
|
||||
"default_id": "databricks/dolly-v2-3b",
|
||||
"model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],
|
||||
}
|
||||
return_full_text: bool = openllm.LLMConfig.Field(
|
||||
False, description="Whether to return the full prompt to the users."
|
||||
)
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
top_p: float = 0.92
|
||||
top_k: int = 5
|
||||
max_new_tokens: int = 256
|
||||
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
|
||||
__config__ = {"timeout": 3600000, "url": "https://github.com/databrickslabs/dolly", "architecture": "GPTNeoXForCausalLM", "default_id": "databricks/dolly-v2-3b", "model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],}
|
||||
return_full_text: bool = openllm.LLMConfig.Field(False, description="Whether to return the full prompt to the users.")
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
top_p: float = 0.92
|
||||
top_k: int = 5
|
||||
max_new_tokens: int = 256
|
||||
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
|
||||
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for dolly-v2 model.
|
||||
|
||||
@@ -74,8 +69,9 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
|
||||
|
||||
def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
|
||||
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
||||
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
||||
@@ -90,6 +86,6 @@ def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str)
|
||||
Returns:
|
||||
int: the token ID for the given key.
|
||||
"""
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
return token_ids[0]
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
return token_ids[0]
|
||||
|
||||
@@ -24,104 +24,130 @@ from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
|
||||
else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow")
|
||||
logger = logging.getLogger(__name__)
|
||||
@t.overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: ...
|
||||
@t.overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: ...
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
|
||||
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
preprocess_params: dict[str, t.Any] = {}
|
||||
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
|
||||
# append a newline to yield a single token. find whatever token is configured for the response key.
|
||||
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
response_key_token_id = None
|
||||
end_key_token_id = None
|
||||
if tokenizer_response_key:
|
||||
try:
|
||||
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
|
||||
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
|
||||
# Ensure generation stops once it generates "### End"
|
||||
generate_kwargs["eos_token_id"] = end_key_token_id
|
||||
except ValueError: pass
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
|
||||
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
def preprocess(self, input_: str, **generate_kwargs: t.Any):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
|
||||
inputs = self.tokenizer(prompt_text, return_tensors="pt")
|
||||
inputs["prompt_text"] = prompt_text
|
||||
inputs["instruction_text"] = input_
|
||||
return inputs
|
||||
def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None)
|
||||
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
|
||||
else: in_b = input_ids.shape[0]
|
||||
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
instruction_text = model_inputs.pop("instruction_text")
|
||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
|
||||
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
|
||||
generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
|
||||
records: list[dict[t.Literal["generated_text"], str]] = []
|
||||
for sequence in generated_sequence:
|
||||
# The response will be set to this variable if we can identify it.
|
||||
decoded = None
|
||||
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
|
||||
if response_key_token_id and end_key_token_id:
|
||||
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
|
||||
# prompt, we should definitely find it. We will return the tokens found after this token.
|
||||
try: response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError: response_pos = None
|
||||
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
|
||||
if response_pos:
|
||||
# Next find where "### End" is located. The model has been trained to end its responses with this
|
||||
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
|
||||
# this token, as the response could be truncated. If we don't find it then just return everything
|
||||
# to the end. Note that even though we set eos_token_id, we still see the this token at the end.
|
||||
try: end_pos = sequence.index(end_key_token_id)
|
||||
except ValueError: end_pos = None
|
||||
decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
|
||||
if not decoded:
|
||||
# Otherwise we'll decode everything and use a regex to find the response and end.
|
||||
fully_decoded = self.tokenizer.decode(sequence)
|
||||
# The response appears after "### Response:". The model has been trained to append "### End" at the
|
||||
# end.
|
||||
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else:
|
||||
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
|
||||
# return everything after "### Response:".
|
||||
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
|
||||
# If the full text is requested, then append the decoded text to the original instruction.
|
||||
# This technically isn't the full text, as we format the instruction in the prompt the model has been
|
||||
# trained on, but to the client it will appear to be the full text.
|
||||
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
|
||||
rec = {"generated_text": decoded}
|
||||
records.append(rec)
|
||||
return records
|
||||
|
||||
if _init: return InstructionTextGenerationPipeline()
|
||||
return InstructionTextGenerationPipeline
|
||||
@t.overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]:
|
||||
...
|
||||
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any):
|
||||
super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
preprocess_params: dict[str, t.Any] = {}
|
||||
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
|
||||
# append a newline to yield a single token. find whatever token is configured for the response key.
|
||||
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
response_key_token_id = None
|
||||
end_key_token_id = None
|
||||
if tokenizer_response_key:
|
||||
try:
|
||||
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
|
||||
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
|
||||
# Ensure generation stops once it generates "### End"
|
||||
generate_kwargs["eos_token_id"] = end_key_token_id
|
||||
except ValueError:
|
||||
pass
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
|
||||
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def preprocess(self, input_: str, **generate_kwargs: t.Any):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
|
||||
inputs = self.tokenizer(prompt_text, return_tensors="pt")
|
||||
inputs["prompt_text"] = prompt_text
|
||||
inputs["instruction_text"] = input_
|
||||
return inputs
|
||||
|
||||
def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None)
|
||||
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
|
||||
else: in_b = input_ids.shape[0]
|
||||
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
instruction_text = model_inputs.pop("instruction_text")
|
||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
|
||||
|
||||
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False):
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
|
||||
generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
|
||||
records: list[dict[t.Literal["generated_text"], str]] = []
|
||||
for sequence in generated_sequence:
|
||||
# The response will be set to this variable if we can identify it.
|
||||
decoded = None
|
||||
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
|
||||
if response_key_token_id and end_key_token_id:
|
||||
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
|
||||
# prompt, we should definitely find it. We will return the tokens found after this token.
|
||||
try:
|
||||
response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError:
|
||||
response_pos = None
|
||||
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
|
||||
if response_pos:
|
||||
# Next find where "### End" is located. The model has been trained to end its responses with this
|
||||
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
|
||||
# this token, as the response could be truncated. If we don't find it then just return everything
|
||||
# to the end. Note that even though we set eos_token_id, we still see the this token at the end.
|
||||
try:
|
||||
end_pos = sequence.index(end_key_token_id)
|
||||
except ValueError:
|
||||
end_pos = None
|
||||
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
|
||||
if not decoded:
|
||||
# Otherwise we'll decode everything and use a regex to find the response and end.
|
||||
fully_decoded = self.tokenizer.decode(sequence)
|
||||
# The response appears after "### Response:". The model has been trained to append "### End" at the
|
||||
# end.
|
||||
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else:
|
||||
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
|
||||
# return everything after "### Response:".
|
||||
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
|
||||
# If the full text is requested, then append the decoded text to the original instruction.
|
||||
# This technically isn't the full text, as we format the instruction in the prompt the model has been
|
||||
# trained on, but to the client it will appear to be the full text.
|
||||
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
|
||||
rec = {"generated_text": decoded}
|
||||
records.append(rec)
|
||||
return records
|
||||
|
||||
if _init: return InstructionTextGenerationPipeline()
|
||||
return InstructionTextGenerationPipeline
|
||||
|
||||
class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self): return {"device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16}, {"padding_side": "left"}
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(model=transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), tokenizer=self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
with torch.inference_mode(): return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self):
|
||||
return {"device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16}, {"padding_side": "left"}
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline:
|
||||
return get_pipeline(model=transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), tokenizer=self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str:
|
||||
return generation_result[0]["generated_text"]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
with torch.inference_mode():
|
||||
return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())
|
||||
|
||||
@@ -17,17 +17,23 @@ import typing as t
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_falcon": ["FalconConfig", "START_FALCON_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_falcon"] = ["Falcon"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_falcon"] = ["Falcon"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING
|
||||
from .configuration_falcon import FalconConfig as FalconConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_falcon import Falcon as Falcon
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING
|
||||
from .configuration_falcon import FalconConfig as FalconConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_falcon import Falcon as Falcon
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,45 +13,26 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class FalconConfig(openllm.LLMConfig):
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
|
||||
|
||||
It is made available under the TII Falcon LLM License.
|
||||
|
||||
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"requires_gpu": True,
|
||||
"timeout": int(36e6),
|
||||
"url": "https://falconllm.tii.ae/",
|
||||
"requirements": ["einops", "xformers"],
|
||||
"architecture": "FalconForCausalLM",
|
||||
"default_id": "tiiuae/falcon-7b",
|
||||
"model_ids": [
|
||||
"tiiuae/falcon-7b",
|
||||
"tiiuae/falcon-40b",
|
||||
"tiiuae/falcon-7b-instruct",
|
||||
"tiiuae/falcon-40b-instruct",
|
||||
],
|
||||
"fine_tune_strategies": (
|
||||
{
|
||||
"adapter_type": "lora",
|
||||
"r": 64,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
"target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
|
||||
},
|
||||
),
|
||||
}
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 200
|
||||
top_k: int = 10
|
||||
num_return_sequences: int = 1
|
||||
num_beams: int = 4
|
||||
early_stopping: bool = True
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "trust_remote_code": True, "requires_gpu": True, "timeout": int(36e6), "url": "https://falconllm.tii.ae/", "requirements": ["einops", "xformers"], "architecture": "FalconForCausalLM", "default_id": "tiiuae/falcon-7b", "model_ids": ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct",],
|
||||
"fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none", "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],},),
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 200
|
||||
top_k: int = 10
|
||||
num_return_sequences: int = 1
|
||||
num_beams: int = 4
|
||||
early_stopping: bool = True
|
||||
|
||||
START_FALCON_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for FalconLM model.
|
||||
|
||||
|
||||
@@ -18,21 +18,31 @@ from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() else None}, {}
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env( eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[: -len(stop_seq)]
|
||||
return [{"generated_text": result}]
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() else None}, {}
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
||||
return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
|
||||
return [{"generated_text": result}]
|
||||
|
||||
@@ -20,33 +20,47 @@ from ...utils import LazyModule
|
||||
from ...utils import is_flax_available
|
||||
from ...utils import is_tf_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_flan_t5": ["FlanT5Config", "START_FLAN_T5_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_flan_t5"] = ["FlanT5"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flan_t5"] = ["FlanT5"]
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"]
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"]
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"]
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING
|
||||
from .configuration_flan_t5 import FlanT5Config as FlanT5Config
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_flan_t5 import FlanT5 as FlanT5
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING
|
||||
from .configuration_flan_t5 import FlanT5Config as FlanT5Config
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flan_t5 import FlanT5 as FlanT5
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,32 +13,23 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class FlanT5Config(openllm.LLMConfig):
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
|
||||
|
||||
It is an enhanced version of T5 that has been finetuned in a mixture of tasks.
|
||||
|
||||
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5",
|
||||
"default_id": "google/flan-t5-large",
|
||||
"architecture": "T5ForConditionalGeneration",
|
||||
"model_ids": [
|
||||
"google/flan-t5-small",
|
||||
"google/flan-t5-base",
|
||||
"google/flan-t5-large",
|
||||
"google/flan-t5-xl",
|
||||
"google/flan-t5-xxl",
|
||||
],
|
||||
"model_type": "seq2seq_lm",
|
||||
}
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 2048
|
||||
top_k: int = 50
|
||||
top_p: float = 0.4
|
||||
repetition_penalty = 1.0
|
||||
__config__ = {"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5", "default_id": "google/flan-t5-large", "architecture": "T5ForConditionalGeneration", "model_ids": ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl",], "model_type": "seq2seq_lm",}
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 2048
|
||||
top_k: int = 50
|
||||
top_p: float = 0.4
|
||||
repetition_penalty = 1.0
|
||||
|
||||
START_FLAN_T5_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for FLAN-T5 model.
|
||||
|
||||
|
||||
@@ -19,20 +19,28 @@ from ..._llm import LLMEmbeddings
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
|
||||
class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, decoder_input_ids=input_ids)
|
||||
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)
|
||||
embeddings.append(data.tolist())
|
||||
num_tokens += len(input_ids[0])
|
||||
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, decoder_input_ids=input_ids)
|
||||
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)
|
||||
embeddings.append(data.tolist())
|
||||
num_tokens += len(input_ids[0])
|
||||
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
|
||||
|
||||
@@ -17,13 +17,18 @@ import openllm
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import transformers # noqa: F401
|
||||
|
||||
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if decoder_start_token_id is None: decoder_start_token_id = 0
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
|
||||
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
|
||||
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="np")["input_ids"], do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), decoder_start_token_id=decoder_start_token_id).sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if decoder_start_token_id is None: decoder_start_token_id = 0
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
|
||||
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
|
||||
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="np")["input_ids"], do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), decoder_start_token_id=decoder_start_token_id).sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
@@ -17,8 +17,15 @@ import openllm
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import transformers # noqa: F401
|
||||
|
||||
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
@@ -17,17 +17,23 @@ import typing as t
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_gpt_neox": ["GPTNeoXConfig", "START_GPT_NEOX_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_gpt_neox"] = ["GPTNeoX"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_gpt_neox"] = ["GPTNeoX"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING
|
||||
from .configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_gpt_neox import GPTNeoX as GPTNeoX
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING
|
||||
from .configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_gpt_neox import GPTNeoX as GPTNeoX
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class GPTNeoXConfig(openllm.LLMConfig):
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
|
||||
|
||||
It is, to the best of our knowledge, the largest dense autoregressive model
|
||||
that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights,
|
||||
@@ -28,19 +29,13 @@ class GPTNeoXConfig(openllm.LLMConfig):
|
||||
Refer to [GPTNeoX's model card](https://huggingface.co/docs/transformers/model_doc/gpt_neox)
|
||||
for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"model_name": "gpt_neox",
|
||||
"start_name": "gpt-neox",
|
||||
"requires_gpu": True,
|
||||
"architecture": "GPTNeoXForCausalLM",
|
||||
"url": "https://github.com/EleutherAI/gpt-neox",
|
||||
"default_id": "eleutherai/gpt-neox-20b",
|
||||
"model_ids": ["eleutherai/gpt-neox-20b"],
|
||||
}
|
||||
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 100
|
||||
__config__ = {"model_name": "gpt_neox", "start_name": "gpt-neox", "requires_gpu": True, "architecture": "GPTNeoXForCausalLM", "url": "https://github.com/EleutherAI/gpt-neox", "default_id": "eleutherai/gpt-neox-20b", "model_ids": ["eleutherai/gpt-neox-20b"],}
|
||||
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 100
|
||||
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for GPTNeoX model.
|
||||
|
||||
|
||||
@@ -20,15 +20,25 @@ from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
|
||||
if self.config.use_half_precision: model.half()
|
||||
return model
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])))
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
|
||||
if self.config.use_half_precision: model.half()
|
||||
return model
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])))
|
||||
|
||||
@@ -18,26 +18,36 @@ from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
from ...utils import is_vllm_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_llama": ["LlamaConfig", "START_LLAMA_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]}
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_vllm_llama"] = ["VLLMLlama"]
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vllm_llama"] = ["VLLMLlama"]
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_llama"] = ["Llama"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_llama"] = ["Llama"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING
|
||||
from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING
|
||||
from .configuration_llama import LlamaConfig as LlamaConfig
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_vllm_llama import VLLMLlama as VLLMLlama
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_llama import Llama as Llama
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING
|
||||
from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING
|
||||
from .configuration_llama import LlamaConfig as LlamaConfig
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vllm_llama import VLLMLlama as VLLMLlama
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_llama import Llama as Llama
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -14,8 +14,9 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
import openllm
|
||||
|
||||
class LlamaConfig(openllm.LLMConfig):
|
||||
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
|
||||
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
|
||||
|
||||
It is a collection of foundation language models ranging from 7B to 65B parameters.
|
||||
|
||||
@@ -26,54 +27,24 @@ class LlamaConfig(openllm.LLMConfig):
|
||||
Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama)
|
||||
for more information.
|
||||
"""
|
||||
use_llama2_prompt: bool = openllm.LLMConfig.Field(True, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.")
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"url": "https://github.com/facebookresearch/llama",
|
||||
"default_id": "huggyllama/llama-7b",
|
||||
"default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"},
|
||||
"architecture": "LlamaForCausalLM",
|
||||
"requirements": ["fairscale", "sentencepiece"],
|
||||
"model_ids": [
|
||||
"meta-llama/Llama-2-70b-chat-hf",
|
||||
"meta-llama/Llama-2-13b-chat-hf",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
"meta-llama/Llama-2-70b-hf",
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"NousResearch/llama-2-70b-chat-hf",
|
||||
"NousResearch/llama-2-13b-chat-hf",
|
||||
"NousResearch/llama-2-7b-chat-hf",
|
||||
"NousResearch/llama-2-70b-hf",
|
||||
"NousResearch/llama-2-13b-hf",
|
||||
"NousResearch/llama-2-7b-hf",
|
||||
"openlm-research/open_llama_7b_v2",
|
||||
"openlm-research/open_llama_3b_v2",
|
||||
"openlm-research/open_llama_13b",
|
||||
"huggyllama/llama-65b",
|
||||
"huggyllama/llama-30b",
|
||||
"huggyllama/llama-13b",
|
||||
"huggyllama/llama-7b",
|
||||
],
|
||||
"tokenizer_class": "LlamaTokenizerFast",
|
||||
"fine_tune_strategies": (
|
||||
{
|
||||
"adapter_type": "lora",
|
||||
"r": 64,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
},
|
||||
),
|
||||
}
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 256
|
||||
temperature: float = 0.45
|
||||
top_p: float = 0.95
|
||||
top_k: int = 12
|
||||
class SamplingParams:
|
||||
best_of: int = 1
|
||||
presence_penalty: float = 0.5
|
||||
use_llama2_prompt: bool = openllm.LLMConfig.Field(True, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.")
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "url": "https://github.com/facebookresearch/llama", "default_id": "huggyllama/llama-7b", "default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"}, "architecture": "LlamaForCausalLM", "requirements": ["fairscale", "sentencepiece"], "model_ids": [
|
||||
"meta-llama/Llama-2-70b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "NousResearch/llama-2-70b-chat-hf", "NousResearch/llama-2-13b-chat-hf", "NousResearch/llama-2-7b-chat-hf", "NousResearch/llama-2-70b-hf", "NousResearch/llama-2-13b-hf",
|
||||
"NousResearch/llama-2-7b-hf", "openlm-research/open_llama_7b_v2", "openlm-research/open_llama_3b_v2", "openlm-research/open_llama_13b", "huggyllama/llama-65b", "huggyllama/llama-30b", "huggyllama/llama-13b", "huggyllama/llama-7b",
|
||||
], "tokenizer_class": "LlamaTokenizerFast", "fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none",},),
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 256
|
||||
temperature: float = 0.45
|
||||
top_p: float = 0.95
|
||||
top_k: int = 12
|
||||
|
||||
class SamplingParams:
|
||||
best_of: int = 1
|
||||
presence_penalty: float = 0.5
|
||||
|
||||
START_LLAMA_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Llama model.
|
||||
|
||||
@@ -112,5 +83,7 @@ SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = "[INST]", "[/INST]", "<<SY
|
||||
# TODO: support history and v1 prompt implementation
|
||||
_v1_prompt, _v2_prompt = """{instruction}""", """{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} """.format(start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction="{instruction}", end_key=EINST_KEY)
|
||||
PROMPT_MAPPING = {"v1": _v1_prompt, "v2": _v2_prompt}
|
||||
|
||||
def _get_prompt(model_type: t.Literal["v1", "v2"]) -> str: return PROMPT_MAPPING[model_type]
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
|
||||
@@ -21,22 +21,28 @@ from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
|
||||
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
|
||||
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
|
||||
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
|
||||
with torch.inference_mode():
|
||||
data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
|
||||
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
||||
masked_embeddings = data * mask
|
||||
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
|
||||
return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
|
||||
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
|
||||
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
|
||||
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
|
||||
with torch.inference_mode():
|
||||
data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
|
||||
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
||||
masked_embeddings = data * mask
|
||||
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
|
||||
return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())
|
||||
|
||||
@@ -19,8 +19,10 @@ from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
|
||||
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None
|
||||
return process_prompt(prompt, _template, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
|
||||
@@ -17,18 +17,24 @@ import typing as t
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_mpt": ["MPTConfig", "START_MPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_mpt"] = ["MPT"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_mpt"] = ["MPT"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING
|
||||
from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING
|
||||
from .configuration_mpt import MPTConfig as MPTConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_mpt import MPT as MPT
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING
|
||||
from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING
|
||||
from .configuration_mpt import MPTConfig as MPTConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_mpt import MPT as MPT
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -16,8 +16,9 @@ import typing as t
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: MPTPromptType = t.Literal["default", "instruct", "chat", "storywriter"]
|
||||
else: MPTPromptType = str
|
||||
|
||||
class MPTConfig(openllm.LLMConfig):
|
||||
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
|
||||
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
|
||||
|
||||
This model was trained by [MosaicML](https://www.mosaicml.com/).
|
||||
|
||||
@@ -25,30 +26,18 @@ class MPTConfig(openllm.LLMConfig):
|
||||
on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
|
||||
for more details on specific models.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"url": "https://huggingface.co/mosaicml",
|
||||
"default_id": "mosaicml/mpt-7b-instruct",
|
||||
"timeout": int(36e6),
|
||||
"requirements": ["triton", "einops"],
|
||||
"architecture": "MPTForCausalLM",
|
||||
"model_ids": [
|
||||
"mosaicml/mpt-7b",
|
||||
"mosaicml/mpt-7b-instruct",
|
||||
"mosaicml/mpt-7b-chat",
|
||||
"mosaicml/mpt-7b-storywriter",
|
||||
"mosaicml/mpt-30b",
|
||||
"mosaicml/mpt-30b-instruct",
|
||||
"mosaicml/mpt-30b-chat",
|
||||
],
|
||||
}
|
||||
prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="""Given prompt type for running MPT. Default will be inferred from model name if pretrained.""")
|
||||
max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)")
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0
|
||||
top_p: float = 0.8
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "trust_remote_code": True, "url": "https://huggingface.co/mosaicml", "default_id": "mosaicml/mpt-7b-instruct", "timeout": int(36e6), "requirements": ["triton", "einops"], "architecture": "MPTForCausalLM",
|
||||
"model_ids": ["mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct", "mosaicml/mpt-7b-chat", "mosaicml/mpt-7b-storywriter", "mosaicml/mpt-30b", "mosaicml/mpt-30b-instruct", "mosaicml/mpt-30b-chat",],
|
||||
}
|
||||
prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="""Given prompt type for running MPT. Default will be inferred from model name if pretrained.""")
|
||||
max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)")
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0
|
||||
top_p: float = 0.8
|
||||
|
||||
START_MPT_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for MPT model.
|
||||
|
||||
@@ -86,5 +75,8 @@ _chat_prompt, _default_prompt, _instruct_prompt = """{instruction}""", """{instr
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
|
||||
PROMPT_MAPPING = {"default": _default_prompt, "instruct": _instruct_prompt, "storywriter": _default_prompt, "chat": _chat_prompt}
|
||||
def _get_prompt(model_type: str) -> str: return PROMPT_MAPPING[model_type]
|
||||
|
||||
def _get_prompt(model_type: str) -> str:
|
||||
return PROMPT_MAPPING[model_type]
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
|
||||
@@ -23,56 +23,71 @@ from ...utils import generate_labels, is_triton_available
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig:
|
||||
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
|
||||
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
|
||||
if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton"
|
||||
else: logger.debug("'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'")
|
||||
# setting max_seq_len
|
||||
config.max_seq_len = max_sequence_length
|
||||
return config
|
||||
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
|
||||
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
|
||||
if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton"
|
||||
else: logger.debug("'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'")
|
||||
# setting max_seq_len
|
||||
config.max_seq_len = max_sequence_length
|
||||
return config
|
||||
|
||||
class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def llm_post_init(self): self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
attrs.pop("low_cpu_mem_usage", None)
|
||||
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
|
||||
try: return bentoml.transformers.save_model( self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally: torch.cuda.empty_cache()
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
trust_remote_code = attrs.pop("trust_remote_code", True)
|
||||
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
model.tie_weights()
|
||||
return model
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
if "instruct" in self.model_id: prompt_type = "instruct"
|
||||
elif "storywriter" in self.model_id: prompt_type = "storywriter"
|
||||
elif "chat" in self.model_id: prompt_type = "chat"
|
||||
else: prompt_type = "default"
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
|
||||
with torch.inference_mode():
|
||||
if torch.cuda.is_available():
|
||||
with torch.autocast("cuda", torch.float16):
|
||||
generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
else: generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def llm_post_init(self):
|
||||
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
attrs.pop("low_cpu_mem_usage", None)
|
||||
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
trust_remote_code = attrs.pop("trust_remote_code", True)
|
||||
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
if "instruct" in self.model_id: prompt_type = "instruct"
|
||||
elif "storywriter" in self.model_id: prompt_type = "storywriter"
|
||||
elif "chat" in self.model_id: prompt_type = "chat"
|
||||
else: prompt_type = "default"
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
|
||||
with torch.inference_mode():
|
||||
if torch.cuda.is_available():
|
||||
with torch.autocast("cuda", torch.float16):
|
||||
generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
else:
|
||||
generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)
|
||||
|
||||
@@ -20,41 +20,59 @@ from ...utils import is_flax_available
|
||||
from ...utils import is_tf_available
|
||||
from ...utils import is_torch_available
|
||||
from ...utils import is_vllm_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_opt": ["OPTConfig", "START_OPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_opt"] = ["OPT"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_opt"] = ["OPT"]
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_flax_opt"] = ["FlaxOPT"]
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_opt"] = ["FlaxOPT"]
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_vllm_opt"] = ["VLLMOPT"]
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vllm_opt"] = ["VLLMOPT"]
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_tf_opt"] = ["TFOPT"]
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_opt"] = ["TFOPT"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING
|
||||
from .configuration_opt import OPTConfig as OPTConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_opt import OPT as OPT
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_flax_opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_vllm_opt import VLLMOPT as VLLMOPT
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_tf_opt import TFOPT as TFOPT
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING
|
||||
from .configuration_opt import OPTConfig as OPTConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_opt import OPT as OPT
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vllm_opt import VLLMOPT as VLLMOPT
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_opt import TFOPT as TFOPT
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class OPTConfig(openllm.LLMConfig):
|
||||
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
|
||||
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
|
||||
|
||||
OPT was predominantly pretrained with English text, but a small amount of non-English data is still present
|
||||
within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM)
|
||||
@@ -23,37 +24,18 @@ class OPTConfig(openllm.LLMConfig):
|
||||
|
||||
Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": False,
|
||||
"url": "https://huggingface.co/docs/transformers/model_doc/opt",
|
||||
"default_id": "facebook/opt-1.3b",
|
||||
"architecture": "OPTForCausalLM",
|
||||
"model_ids": [
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-2.7b",
|
||||
"facebook/opt-6.7b",
|
||||
"facebook/opt-66b",
|
||||
],
|
||||
"fine_tune_strategies": (
|
||||
{
|
||||
"adapter_type": "lora",
|
||||
"r": 16,
|
||||
"lora_alpha": 32,
|
||||
"target_modules": ["q_proj", "v_proj"],
|
||||
"lora_dropout": 0.05,
|
||||
"bias": "none",
|
||||
},
|
||||
),
|
||||
}
|
||||
format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""")
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 1024
|
||||
num_return_sequences: int = 1
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "trust_remote_code": False, "url": "https://huggingface.co/docs/transformers/model_doc/opt", "default_id": "facebook/opt-1.3b", "architecture": "OPTForCausalLM", "model_ids": ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-66b",],
|
||||
"fine_tune_strategies": ({"adapter_type": "lora", "r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], "lora_dropout": 0.05, "bias": "none",},),
|
||||
}
|
||||
format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""")
|
||||
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 1024
|
||||
num_return_sequences: int = 1
|
||||
|
||||
START_OPT_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for OPT model.
|
||||
|
||||
|
||||
@@ -22,18 +22,26 @@ from ...utils import generate_labels
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {}, {"padding_side": "left", "truncation_side": "left"}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode( self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True)
|
||||
class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {}, {"padding_side": "left", "truncation_side": "left"}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True)
|
||||
|
||||
@@ -18,23 +18,34 @@ import openllm
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING:
|
||||
import torch, transformers
|
||||
import torch, transformers
|
||||
else:
|
||||
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
def llm_post_init(self): self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left", "truncation_side": "left"}
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs)
|
||||
return model
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
__openllm_internal__ = True
|
||||
|
||||
def llm_post_init(self):
|
||||
self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left", "truncation_side": "left"}
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs)
|
||||
return model
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
@@ -22,17 +22,26 @@ from ...utils import generate_labels
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {}, {"padding_side": "left", "truncation_side": "left"}
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {}, {"padding_side": "left", "truncation_side": "left"}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
@@ -19,9 +19,13 @@ import openllm
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm, transformers
|
||||
import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMOPT(openllm.LLM["vllm.LLMEngine", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
|
||||
@@ -17,17 +17,23 @@ import typing as t
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_stablelm": ["StableLMConfig", "START_STABLELM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_stablelm"] = ["StableLM"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_stablelm"] = ["StableLM"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING
|
||||
from .configuration_stablelm import StableLMConfig as StableLMConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_stablelm import StableLM as StableLM
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING
|
||||
from .configuration_stablelm import StableLMConfig as StableLMConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_stablelm import StableLM as StableLM
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class StableLMConfig(openllm.LLMConfig):
|
||||
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
|
||||
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
|
||||
|
||||
It is pre-trained on a diverse collection of English datasets with a sequence
|
||||
length of 4096 to push beyond the context window limitations of existing open-source language models.
|
||||
@@ -27,23 +28,14 @@ class StableLMConfig(openllm.LLMConfig):
|
||||
and [StableLM-base's model card](https://huggingface.co/stabilityai/stablelm-base-alpha-7b)
|
||||
for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"url": "https://github.com/Stability-AI/StableLM",
|
||||
"architecture": "GPTNeoXForCausalLM",
|
||||
"default_id": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
"model_ids": [
|
||||
"stabilityai/stablelm-tuned-alpha-3b",
|
||||
"stabilityai/stablelm-tuned-alpha-7b",
|
||||
"stabilityai/stablelm-base-alpha-3b",
|
||||
"stabilityai/stablelm-base-alpha-7b",
|
||||
],
|
||||
}
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 128
|
||||
top_k: int = 0
|
||||
top_p: float = 0.9
|
||||
__config__ = {"name_type": "lowercase", "url": "https://github.com/Stability-AI/StableLM", "architecture": "GPTNeoXForCausalLM", "default_id": "stabilityai/stablelm-tuned-alpha-3b", "model_ids": ["stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b",],}
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 128
|
||||
top_k: int = 0
|
||||
top_p: float = 0.9
|
||||
|
||||
START_STABLELM_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for StableLM model.
|
||||
|
||||
|
||||
@@ -21,17 +21,28 @@ from ..._prompt import process_prompt
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def llm_post_init(self): self.bettertransformer = True if not torch.cuda.is_available() else False
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if "tuned" in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
|
||||
else: prompt_text = prompt
|
||||
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)]
|
||||
__openllm_internal__ = True
|
||||
|
||||
def llm_post_init(self):
|
||||
self.bettertransformer = True if not torch.cuda.is_available() else False
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if "tuned" in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
|
||||
else:
|
||||
prompt_text = prompt
|
||||
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)]
|
||||
|
||||
@@ -17,17 +17,23 @@ import typing as t
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_starcoder": ["StarCoderConfig", "START_STARCODER_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: _import_structure["modeling_starcoder"] = ["StarCoder"]
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_starcoder"] = ["StarCoder"]
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING
|
||||
from .configuration_starcoder import StarCoderConfig as StarCoderConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
else: from .modeling_starcoder import StarCoder as StarCoder
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING
|
||||
from .configuration_starcoder import StarCoderConfig as StarCoderConfig
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
from .modeling_starcoder import StarCoder as StarCoder
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class StarCoderConfig(openllm.LLMConfig):
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
|
||||
The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150),
|
||||
[a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the
|
||||
@@ -22,24 +23,17 @@ class StarCoderConfig(openllm.LLMConfig):
|
||||
|
||||
Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"requires_gpu": True,
|
||||
"url": "https://github.com/bigcode-project/starcoder",
|
||||
"architecture": "GPTBigCodeForCausalLM",
|
||||
"requirements": ["bitsandbytes"],
|
||||
"workers_per_resource": 0.5,
|
||||
"default_id": "bigcode/starcoder",
|
||||
"model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],
|
||||
}
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.2
|
||||
max_new_tokens: int = 256
|
||||
min_new_tokens: int = 32
|
||||
top_k: float = 50
|
||||
top_p: float = 0.95
|
||||
pad_token_id: int = 49152
|
||||
repetition_penalty: float = 1.2
|
||||
__config__ = {"name_type": "lowercase", "requires_gpu": True, "url": "https://github.com/bigcode-project/starcoder", "architecture": "GPTBigCodeForCausalLM", "requirements": ["bitsandbytes"], "workers_per_resource": 0.5, "default_id": "bigcode/starcoder", "model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],}
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.2
|
||||
max_new_tokens: int = 256
|
||||
min_new_tokens: int = 32
|
||||
top_k: float = 50
|
||||
top_p: float = 0.95
|
||||
pad_token_id: int = 49152
|
||||
repetition_penalty: float = 1.2
|
||||
|
||||
START_STARCODER_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for StarCoder model.
|
||||
|
||||
|
||||
@@ -18,47 +18,61 @@ import bentoml
|
||||
import openllm
|
||||
from ...utils import generate_labels
|
||||
if t.TYPE_CHECKING:
|
||||
import torch, transformers
|
||||
import torch, transformers
|
||||
else:
|
||||
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = "<fim-prefix>", "<fim-middle>", "<fim-suffix>", "<fim-pad>", "<|endoftext|>", "<FILL_HERE>"
|
||||
|
||||
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally: torch.cuda.empty_cache()
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try: prefix, suffix = prompt.split(FIM_INDICATOR)
|
||||
except Exception as err: raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
|
||||
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
|
||||
else: prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
|
||||
# default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
|
||||
# NOTE: support fine-tuning starcoder
|
||||
result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), do_sample=True, pad_token_id=self.tokenizer.eos_token_id, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
# TODO: We will probably want to return the tokenizer here so that we can manually process this
|
||||
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
|
||||
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[: -len(stop_seq)]
|
||||
return [{"generated_text": result}]
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {"padding_side": "left"}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try:
|
||||
prefix, suffix = prompt.split(FIM_INDICATOR)
|
||||
except Exception as err:
|
||||
raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
|
||||
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
|
||||
else:
|
||||
prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
|
||||
# default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
|
||||
# NOTE: support fine-tuning starcoder
|
||||
result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), do_sample=True, pad_token_id=self.tokenizer.eos_token_id, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
# TODO: We will probably want to return the tokenizer here so that we can manually process this
|
||||
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
|
||||
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
|
||||
return [{"generated_text": result}]
|
||||
|
||||
@@ -11,7 +11,6 @@ import torch
|
||||
import openllm
|
||||
import transformers
|
||||
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
@@ -21,90 +20,54 @@ logger = logging.getLogger(__name__)
|
||||
# On notebook, make sure to install the following
|
||||
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
DEFAULT_MODEL_ID = "ybelkada/falcon-7b-sharded-bf16"
|
||||
DATASET_NAME = "timdettmers/openassistant-guanaco"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
optim: str = dataclasses.field(default="paged_adamw_32bit")
|
||||
save_steps: int = dataclasses.field(default=10)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=500)
|
||||
logging_steps: int = dataclasses.field(default=10)
|
||||
learning_rate: float = dataclasses.field(default=2e-4)
|
||||
max_grad_norm: float = dataclasses.field(default=0.3)
|
||||
warmup_ratio: float = dataclasses.field(default=0.03)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default="constant")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
|
||||
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
optim: str = dataclasses.field(default="paged_adamw_32bit")
|
||||
save_steps: int = dataclasses.field(default=10)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=500)
|
||||
logging_steps: int = dataclasses.field(default=10)
|
||||
learning_rate: float = dataclasses.field(default=2e-4)
|
||||
max_grad_norm: float = dataclasses.field(default=0.3)
|
||||
warmup_ratio: float = dataclasses.field(default=0.03)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default="constant")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(
|
||||
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
|
||||
)
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
model, tokenizer = openllm.AutoLLM.for_model(
|
||||
"falcon",
|
||||
model_id=model_args.model_id,
|
||||
quantize="int4",
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
ensure_available=True,
|
||||
).prepare_for_training(
|
||||
adapter_type="lora",
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules=[
|
||||
"query_key_value",
|
||||
"dense",
|
||||
"dense_h_to_4h",
|
||||
"dense_4h_to_h",
|
||||
],
|
||||
)
|
||||
model, tokenizer = openllm.AutoLLM.for_model("falcon", model_id=model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, ensure_available=True,).prepare_for_training(adapter_type="lora", lora_alpha=16, lora_dropout=0.1, r=16, bias="none", target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",],)
|
||||
model.config.use_cache = False
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split="train")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir),
|
||||
**dataclasses.asdict(training_args),
|
||||
),
|
||||
)
|
||||
trainer = SFTTrainer(model=model, train_dataset=dataset, dataset_text_field="text", max_seq_length=model_args.max_sequence_length, tokenizer=tokenizer, args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args),),)
|
||||
|
||||
# upcast layernorm in float32 for more stable training
|
||||
for name, module in trainer.model.named_modules():
|
||||
if "norm" in name:
|
||||
module = module.to(torch.float32)
|
||||
if "norm" in name:
|
||||
module = module.to(torch.float32)
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -15,45 +14,42 @@ MAX_NEW_TOKENS = 384
|
||||
Q = "Answer the following question, step by step:\n{q}\nA:"
|
||||
question = "What is the meaning of life?"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("question", default=question)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("question", default=question)
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
args = parser.parse_args(args=[question])
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
if openllm.utils.in_notebook():
|
||||
args = parser.parse_args(args=[question])
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True)
|
||||
prompt = Q.format(q=args.question)
|
||||
model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True)
|
||||
prompt = Q.format(q=args.question)
|
||||
|
||||
logger.info("-" * 50, "Running with 'generate()'", "-" * 50)
|
||||
res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
|
||||
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
|
||||
logger.info("-" * 50, "Running with 'generate()'", "-" * 50)
|
||||
res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
|
||||
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
|
||||
|
||||
logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50)
|
||||
res = model.generate(prompt, num_return_sequences=3)
|
||||
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
|
||||
logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50)
|
||||
res = model.generate(prompt, num_return_sequences=3)
|
||||
logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
|
||||
|
||||
logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50)
|
||||
r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True)
|
||||
res = r.generate.run(prompt)
|
||||
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
|
||||
logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50)
|
||||
r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True)
|
||||
res = r.generate.run(prompt)
|
||||
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
|
||||
|
||||
logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50)
|
||||
res = r(prompt)
|
||||
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
|
||||
|
||||
return 0
|
||||
logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50)
|
||||
res = r(prompt)
|
||||
logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
|
||||
|
||||
return 0
|
||||
|
||||
def _mp_fn(index: t.Any): # noqa # type: ignore
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
main()
|
||||
main()
|
||||
else:
|
||||
raise SystemExit(main())
|
||||
raise SystemExit(main())
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
import transformers
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
import peft
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
@@ -29,19 +29,17 @@ from itertools import chain
|
||||
from functools import partial
|
||||
from random import randrange
|
||||
|
||||
|
||||
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
|
||||
def find_all_linear_names(model):
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bnb.nn.Linear4bit):
|
||||
names = name.split(".")
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
return list(lora_module_names)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bnb.nn.Linear4bit):
|
||||
names = name.split(".")
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
return list(lora_module_names)
|
||||
|
||||
# Change this to the local converted path if you don't have access to the meta-llama model
|
||||
DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
|
||||
@@ -49,202 +47,149 @@ DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
|
||||
DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135"
|
||||
DATASET_NAME = "databricks/databricks-dolly-15k"
|
||||
|
||||
|
||||
def format_dolly(sample):
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
|
||||
response = f"### Answer\n{sample['response']}"
|
||||
# join all the parts together
|
||||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
|
||||
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
|
||||
response = f"### Answer\n{sample['response']}"
|
||||
# join all the parts together
|
||||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
|
||||
|
||||
# template dataset to add prompt to each sample
|
||||
def template_dataset(sample, tokenizer):
|
||||
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
|
||||
return sample
|
||||
|
||||
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
|
||||
return sample
|
||||
|
||||
# empty list to save remainder from batches to use in next batch
|
||||
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
|
||||
|
||||
|
||||
def chunk(sample, chunk_length=2048):
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
global remainder
|
||||
# Concatenate all texts and add remainder from previous batch
|
||||
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
|
||||
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
|
||||
# get total number of tokens for batch
|
||||
batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
global remainder
|
||||
# Concatenate all texts and add remainder from previous batch
|
||||
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
|
||||
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
|
||||
# get total number of tokens for batch
|
||||
batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
|
||||
|
||||
# get max number of chunks for batch
|
||||
if batch_total_length >= chunk_length:
|
||||
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
# get max number of chunks for batch
|
||||
if batch_total_length >= chunk_length:
|
||||
batch_chunk_length = (batch_total_length//chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] for k, t in concatenated_examples.items()}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split="train")
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split="train")
|
||||
|
||||
print(f"dataset size: {len(dataset)}")
|
||||
print(dataset[randrange(len(dataset))])
|
||||
print(f"dataset size: {len(dataset)}")
|
||||
print(dataset[randrange(len(dataset))])
|
||||
|
||||
# apply prompt template per sample
|
||||
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
|
||||
# print random sample
|
||||
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
|
||||
# apply prompt template per sample
|
||||
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
|
||||
# print random sample
|
||||
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
|
||||
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(
|
||||
lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)
|
||||
).map(
|
||||
partial(chunk, chunk_length=2048),
|
||||
batched=True,
|
||||
)
|
||||
|
||||
# Print total number of samples
|
||||
print(f"Total number of samples: {len(lm_dataset)}")
|
||||
return lm_dataset
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True,)
|
||||
|
||||
# Print total number of samples
|
||||
print(f"Total number of samples: {len(lm_dataset)}")
|
||||
return lm_dataset
|
||||
|
||||
@openllm.utils.requires_dependencies("peft", extra="fine-tune")
|
||||
def prepare_for_int4_training(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
gradient_checkpointing: bool = True,
|
||||
bf16: bool = True,
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
def prepare_for_int4_training(model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True,) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
llm = openllm.AutoLLM.for_model(
|
||||
"llama",
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
ensure_available=True,
|
||||
quantize="int4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map="auto",
|
||||
)
|
||||
print("Model summary:", llm.model)
|
||||
llm = openllm.AutoLLM.for_model("llama", model_id=model_id, model_version=model_version, ensure_available=True, quantize="int4", bnb_4bit_compute_dtype=torch.bfloat16, use_cache=not gradient_checkpointing, device_map="auto",)
|
||||
print("Model summary:", llm.model)
|
||||
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f"Found {len(modules)} modules to quantize: {modules}")
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f"Found {len(modules)} modules to quantize: {modules}")
|
||||
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
|
||||
)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
if bf16:
|
||||
module = module.to(torch.bfloat16)
|
||||
if "norm" in name:
|
||||
module = module.to(torch.float32)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
if bf16:
|
||||
module = module.to(torch.bfloat16)
|
||||
if "norm" in name:
|
||||
module = module.to(torch.float32)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
gradient_checkpointing: bool = dataclasses.field(default=True)
|
||||
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
|
||||
learning_rate: float = dataclasses.field(default=5e-5)
|
||||
num_train_epochs: int = dataclasses.field(default=3)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
report_to: str = dataclasses.field(default="none")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
|
||||
save_strategy: str = dataclasses.field(default="no")
|
||||
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
gradient_checkpointing: bool = dataclasses.field(default=True)
|
||||
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
|
||||
learning_rate: float = dataclasses.field(default=5e-5)
|
||||
num_train_epochs: int = dataclasses.field(default=3)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
report_to: str = dataclasses.field(default="none")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
|
||||
save_strategy: str = dataclasses.field(default="no")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
model_args, training_rags = ModelArguments(), TrainingArguments()
|
||||
model_args, training_rags = ModelArguments(), TrainingArguments()
|
||||
else:
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(
|
||||
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
|
||||
)
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
# import the model first hand
|
||||
openllm.import_model("llama", model_id=model_args.model_id, model_version=model_args.model_version)
|
||||
|
||||
|
||||
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
import peft
|
||||
import peft
|
||||
|
||||
transformers.set_seed(model_args.seed)
|
||||
transformers.set_seed(model_args.seed)
|
||||
|
||||
model, tokenizer = prepare_for_int4_training(
|
||||
model_args.model_id,
|
||||
gradient_checkpointing=training_args.gradient_checkpointing,
|
||||
bf16=training_args.bf16,
|
||||
)
|
||||
datasets = prepare_datasets(tokenizer)
|
||||
model, tokenizer = prepare_for_int4_training(model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16,)
|
||||
datasets = prepare_datasets(tokenizer)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator,
|
||||
)
|
||||
trainer = transformers.Trainer(model=model, args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)), train_dataset=datasets, data_collator=transformers.default_data_collator,)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
|
||||
if model_args.merge_weights:
|
||||
# note that this will requires larger GPU as we will load the whole model into memory
|
||||
if model_args.merge_weights:
|
||||
# note that this will requires larger GPU as we will load the whole model into memory
|
||||
|
||||
# merge adapter weights with base model and save
|
||||
# save int4 model
|
||||
trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)
|
||||
# merge adapter weights with base model and save
|
||||
# save int4 model
|
||||
trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)
|
||||
|
||||
# gc mem
|
||||
del model, trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(
|
||||
training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
||||
)
|
||||
# merge lora with base weights and save
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained(
|
||||
os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB"
|
||||
)
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
# gc mem
|
||||
del model, trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16)
|
||||
# merge lora with base weights and save
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB")
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
|
||||
train_loop(model_args, training_args)
|
||||
|
||||
@@ -20,71 +20,38 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from peft import PeftModel
|
||||
from peft import PeftModel
|
||||
|
||||
DEFAULT_MODEL_ID = "facebook/opt-6.7b"
|
||||
|
||||
|
||||
def load_trainer(
|
||||
model: PeftModel,
|
||||
tokenizer: transformers.GPT2TokenizerFast,
|
||||
dataset_dict: t.Any,
|
||||
training_args: TrainingArguments,
|
||||
):
|
||||
return transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=dataset_dict["train"],
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir),
|
||||
**dataclasses.asdict(training_args),
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
||||
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments,):
|
||||
return transformers.Trainer(model=model, train_dataset=dataset_dict["train"], args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args),), data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=50)
|
||||
learning_rate: float = dataclasses.field(default=3e-4)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
|
||||
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=50)
|
||||
learning_rate: float = dataclasses.field(default=3e-4)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(
|
||||
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
|
||||
)
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
|
||||
model, tokenizer = openllm.AutoLLM.for_model(
|
||||
"opt",
|
||||
model_id=model_args.model_id,
|
||||
quantize="int8",
|
||||
ensure_available=True,
|
||||
).prepare_for_training(
|
||||
adapter_type="lora",
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
)
|
||||
model, tokenizer = openllm.AutoLLM.for_model("opt", model_id=model_args.model_id, quantize="int8", ensure_available=True,).prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none",)
|
||||
|
||||
# ft on english_quotes
|
||||
data = load_dataset("Abirate/english_quotes")
|
||||
|
||||
@@ -49,104 +49,80 @@ from ..utils import LazyLoader
|
||||
from ..utils import LazyModule
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import transformers
|
||||
import bentoml
|
||||
import transformers
|
||||
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
else:
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model:
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.get(llm, auto_import=auto_import)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.get(llm, auto_import=auto_import)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.get(llm, auto_import=auto_import)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.get(llm, auto_import=auto_import)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, **attrs: t.Any) -> None:
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.save_pretrained(llm, save_directory, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.save_pretrained(llm, save_directory, **attrs)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.save_pretrained(llm, save_directory, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.save_pretrained(llm, save_directory, **attrs)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.load_model(llm, *decls, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.load_model(llm, *decls, **attrs)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.load_model(llm, *decls, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
return openllm.ggml.load_model(llm, *decls, **attrs)
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
from .transformers import infer_tokenizers_class_for_llm
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
from .transformers import infer_tokenizers_class_for_llm
|
||||
|
||||
bentomodel_fs = llm._bentomodel._fs
|
||||
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
|
||||
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
|
||||
try:
|
||||
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
|
||||
except KeyError:
|
||||
# This could happen if users implement their own import_model
|
||||
raise OpenLLMException(
|
||||
"Model does not have tokenizer. Make sure to save \
|
||||
bentomodel_fs = llm._bentomodel._fs
|
||||
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
|
||||
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
|
||||
try:
|
||||
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
|
||||
except KeyError:
|
||||
# This could happen if users implement their own import_model
|
||||
raise OpenLLMException("Model does not have tokenizer. Make sure to save \
|
||||
the tokenizer within the model via 'custom_objects'.\
|
||||
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
|
||||
) from None
|
||||
else:
|
||||
tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(
|
||||
bentomodel_fs.getsyspath("/"),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
**tokenizer_attrs,
|
||||
)
|
||||
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))") from None
|
||||
else:
|
||||
tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs,)
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return tokenizer
|
||||
return tokenizer
|
||||
|
||||
|
||||
_extras = {
|
||||
"get": get,
|
||||
"import_model": import_model,
|
||||
"save_pretrained": save_pretrained,
|
||||
"load_model": load_model,
|
||||
"load_tokenizer": load_tokenizer,
|
||||
}
|
||||
_extras = {"get": get, "import_model": import_model, "save_pretrained": save_pretrained, "load_model": load_model, "load_tokenizer": load_tokenizer,}
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": []}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from . import ggml as ggml
|
||||
from . import transformers as transformers
|
||||
from . import ggml as ggml
|
||||
from . import transformers as transformers
|
||||
else:
|
||||
import sys
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects=_extras,
|
||||
)
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras,)
|
||||
|
||||
@@ -17,21 +17,8 @@ from __future__ import annotations
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
|
||||
# NOTE: vllm will use PyTorch implementation of transformers for serialisation
|
||||
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
|
||||
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
|
||||
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
|
||||
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
|
||||
}
|
||||
|
||||
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
HUB_ATTRS = [
|
||||
"cache_dir",
|
||||
"code_revision",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"use_auth_token",
|
||||
]
|
||||
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token",]
|
||||
|
||||
@@ -23,54 +23,42 @@ import bentoml
|
||||
from ..exceptions import OpenLLMException
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
import openllm
|
||||
|
||||
from .._llm import M
|
||||
from .._llm import M
|
||||
|
||||
_conversion_strategy = {"pt": "ggml"}
|
||||
|
||||
def import_model(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
*decls: t.Any,
|
||||
trust_remote_code: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Model:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in ("openllm.serialisation.ggml", __name__):
|
||||
raise bentoml.exceptions.NotFound(
|
||||
f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'."
|
||||
)
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
raise OpenLLMException(
|
||||
f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}."
|
||||
)
|
||||
return model
|
||||
except bentoml.exceptions.NotFound:
|
||||
if auto_import:
|
||||
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in ("openllm.serialisation.ggml", __name__):
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound:
|
||||
if auto_import:
|
||||
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise
|
||||
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
"""Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
"""Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
@@ -38,237 +38,183 @@ from ..utils import lenient_issubclass
|
||||
from ..utils import normalize_attrs_to_model_tokenizer_pair
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import auto_gptq as autogptq
|
||||
import torch
|
||||
import torch.cuda
|
||||
import vllm
|
||||
import auto_gptq as autogptq
|
||||
import torch
|
||||
import torch.cuda
|
||||
import vllm
|
||||
|
||||
import openllm
|
||||
import transformers as _transformers
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
import openllm
|
||||
import transformers as _transformers
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
from .._types import DictStrAny
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
from .._types import DictStrAny
|
||||
else:
|
||||
vllm = LazyLoader("vllm", globals(), "vllm")
|
||||
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
_transformers = LazyLoader("_transformers", globals(), "transformers")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
torch.cuda = LazyLoader("torch.cuda", globals(), "torch.cuda")
|
||||
vllm = LazyLoader("vllm", globals(), "vllm")
|
||||
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
_transformers = LazyLoader("_transformers", globals(), "transformers")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
torch.cuda = LazyLoader("torch.cuda", globals(), "torch.cuda")
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def process_transformers_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[_transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]:
|
||||
"""Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs."""
|
||||
config = attrs.pop("config", None)
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, _transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
|
||||
config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
return config, hub_attrs, attrs
|
||||
"""Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs."""
|
||||
config = attrs.pop("config", None)
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, _transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
|
||||
config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
def infer_tokenizers_class_for_llm(__llm: openllm.LLM[t.Any, T]) -> T:
|
||||
tokenizer_class = __llm.config["tokenizer_class"]
|
||||
if tokenizer_class is None: tokenizer_class = "AutoTokenizer"
|
||||
__cls = getattr(_transformers, tokenizer_class)
|
||||
if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.")
|
||||
return __cls
|
||||
tokenizer_class = __llm.config["tokenizer_class"]
|
||||
if tokenizer_class is None: tokenizer_class = "AutoTokenizer"
|
||||
__cls = getattr(_transformers, tokenizer_class)
|
||||
if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.")
|
||||
return __cls
|
||||
|
||||
def infer_autoclass_from_llm_config(llm: openllm.LLM[M, T], config: _transformers.PretrainedConfig) -> _BaseAutoModelClass:
|
||||
if llm.config["trust_remote_code"]:
|
||||
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
|
||||
if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = "AutoModel"
|
||||
return getattr(_transformers, autoclass)
|
||||
else:
|
||||
if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
|
||||
elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
|
||||
else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.")
|
||||
return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
|
||||
if llm.config["trust_remote_code"]:
|
||||
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
|
||||
if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = "AutoModel"
|
||||
return getattr(_transformers, autoclass)
|
||||
else:
|
||||
if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
|
||||
elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
|
||||
else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.")
|
||||
return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
|
||||
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model:
|
||||
"""Auto detect model type from given model_id and import it to bentoml's model store.
|
||||
"""Auto detect model type from given model_id and import it to bentoml's model store.
|
||||
|
||||
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
|
||||
returning all of the unused kwargs.
|
||||
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
|
||||
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
|
||||
returning all of the unused kwargs.
|
||||
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
|
||||
|
||||
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
|
||||
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
|
||||
|
||||
Refer to Transformers documentation for more information about kwargs.
|
||||
Refer to Transformers documentation for more information about kwargs.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance for this given model.
|
||||
trust_remote_code: Whether to trust the remote code when loading the model.
|
||||
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
"""
|
||||
config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize_method = llm._quantize_method
|
||||
safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
|
||||
# Disable safe serialization with vLLM
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
|
||||
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False}
|
||||
signatures: DictStrAny = {}
|
||||
if quantize_method == "gptq":
|
||||
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_safetensors=safe_serialisation,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
|
||||
signatures["generate"] = {"batchable": False}
|
||||
else:
|
||||
# this model might be called with --quantize int4, therefore we need to pop this out
|
||||
# since saving int4 is not yet supported
|
||||
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
|
||||
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_safetensors=safe_serialisation,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.framework})
|
||||
Args:
|
||||
llm: The LLM instance for this given model.
|
||||
trust_remote_code: Whether to trust the remote code when loading the model.
|
||||
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
"""
|
||||
config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize_method = llm._quantize_method
|
||||
safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
|
||||
# Disable safe serialization with vLLM
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
|
||||
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False}
|
||||
signatures: DictStrAny = {}
|
||||
if quantize_method == "gptq":
|
||||
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,)
|
||||
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
|
||||
signatures["generate"] = {"batchable": False}
|
||||
else:
|
||||
# this model might be called with --quantize int4, therefore we need to pop this out
|
||||
# since saving int4 is not yet supported
|
||||
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
|
||||
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,)
|
||||
metadata.update({"_pretrained_class": model.__class__.__name__, "_framework": model.framework})
|
||||
|
||||
_tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(
|
||||
llm.model_id,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**tokenizer_attrs,
|
||||
)
|
||||
if _tokenizer.pad_token is None: _tokenizer.pad_token = _tokenizer.eos_token
|
||||
_tokenizer = infer_tokenizers_class_for_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs,)
|
||||
if _tokenizer.pad_token is None: _tokenizer.pad_token = _tokenizer.eos_token
|
||||
|
||||
# NOTE: quick hack to set the loaded into llm object to use with save_pretrained
|
||||
# to avoid recursive call when the model is not yet available in local store
|
||||
_object_setattr(llm, "__llm_model__", model)
|
||||
_object_setattr(llm, "__llm_tokenizer__", _tokenizer)
|
||||
create_kwargs: DictStrAny= dict(
|
||||
module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), labels=generate_labels(llm), metadata=metadata,
|
||||
signatures=signatures if signatures else make_default_signatures(model), options=ModelOptions(),
|
||||
external_modules=[importlib.import_module(model.__module__), importlib.import_module(_tokenizer.__module__)] if trust_remote_code else None,
|
||||
)
|
||||
if "use_tempfs" in inspect.signature(bentoml.models.create).parameters: create_kwargs["use_tempfs"] = False
|
||||
# NOTE: quick hack to set the loaded into llm object to use with save_pretrained
|
||||
# to avoid recursive call when the model is not yet available in local store
|
||||
_object_setattr(llm, "__llm_model__", model)
|
||||
_object_setattr(llm, "__llm_tokenizer__", _tokenizer)
|
||||
create_kwargs: DictStrAny = dict(
|
||||
module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), labels=generate_labels(llm), metadata=metadata, signatures=signatures if signatures else make_default_signatures(model), options=ModelOptions(), external_modules=[importlib.import_module(model.__module__),
|
||||
importlib.import_module(_tokenizer.__module__)] if trust_remote_code else None,
|
||||
)
|
||||
if "use_tempfs" in inspect.signature(bentoml.models.create).parameters: create_kwargs["use_tempfs"] = False
|
||||
|
||||
try:
|
||||
with bentoml.models.create(llm.tag, **create_kwargs) as bentomodel:
|
||||
save_pretrained(llm, bentomodel.path, safe_serialization=safe_serialisation)
|
||||
return bentomodel
|
||||
finally:
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model
|
||||
# available locally.
|
||||
if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
try:
|
||||
with bentoml.models.create(llm.tag, **create_kwargs) as bentomodel:
|
||||
save_pretrained(llm, bentomodel.path, safe_serialization=safe_serialisation)
|
||||
return bentomodel
|
||||
finally:
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model
|
||||
# available locally.
|
||||
if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
|
||||
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
# compat with bentoml.transformers.get
|
||||
if model.info.module not in ("openllm.serialisation.transformers", __name__):
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise err from None
|
||||
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
# compat with bentoml.transformers.get
|
||||
if model.info.module not in ("openllm.serialisation.transformers", __name__):
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise err from None
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
"""Load the model from BentoML store.
|
||||
"""Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
config, hub_attrs, attrs = process_transformers_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
|
||||
safe_serialization = first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors")
|
||||
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
|
||||
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm._bentomodel.path,
|
||||
*decls,
|
||||
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
use_safetensors=safe_serialization,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
config, hub_attrs, attrs = process_transformers_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
|
||||
safe_serialization = first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors")
|
||||
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
|
||||
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=llm.__llm_trust_remote_code__, use_safetensors=safe_serialization, **hub_attrs, **attrs,)
|
||||
|
||||
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
|
||||
llm._bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
# NOTE: we only cast and load the model if it is not already quantized and setup correctly
|
||||
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_quantized", False)
|
||||
if torch.cuda.is_available() and device_count() == 1 and not loaded_in_kbit:
|
||||
try: model = model.to("cuda")
|
||||
except torch.cuda.OutOfMemoryError as err: raise RuntimeError(f"Failed to convert {llm.config['model_name']} with model_id '{llm.model_id}' to CUDA.\nNote: You can try out '--quantize int8 | int4' for dynamic quantization.") from err
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
if llm.bettertransformer and isinstance(model, _transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
return t.cast("M", model)
|
||||
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, **hub_attrs, **attrs,)
|
||||
# NOTE: we only cast and load the model if it is not already quantized and setup correctly
|
||||
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_quantized", False)
|
||||
if torch.cuda.is_available() and device_count() == 1 and not loaded_in_kbit:
|
||||
try:
|
||||
model = model.to("cuda")
|
||||
except torch.cuda.OutOfMemoryError as err:
|
||||
raise RuntimeError(f"Failed to convert {llm.config['model_name']} with model_id '{llm.model_id}' to CUDA.\nNote: You can try out '--quantize int8 | int4' for dynamic quantization.") from err
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
if llm.bettertransformer and isinstance(model, _transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
return t.cast("M", model)
|
||||
|
||||
def save_pretrained(
|
||||
llm: openllm.LLM[M, T],
|
||||
save_directory: str,
|
||||
is_main_process: bool = True,
|
||||
state_dict: DictStrAny | None = None,
|
||||
save_function: t.Callable[..., None] | None = None,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: int | str = "10GB",
|
||||
safe_serialization: bool = False,
|
||||
variant: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> None:
|
||||
"""Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``."""
|
||||
save_function = first_not_none(save_function, default=torch.save)
|
||||
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
|
||||
# NOTE: disable safetensors for vllm
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialization = False
|
||||
if llm._quantize_method == "gptq":
|
||||
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
|
||||
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
|
||||
elif LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model): raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
|
||||
elif isinstance(llm.model, _transformers.Pipeline): llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
|
||||
else:
|
||||
# We can safely cast here since it will be the PreTrainedModel protocol.
|
||||
t.cast("_transformers.PreTrainedModel", llm.model).save_pretrained(
|
||||
save_directory,
|
||||
is_main_process=is_main_process,
|
||||
state_dict=state_dict,
|
||||
save_function=save_function,
|
||||
push_to_hub=push_to_hub,
|
||||
max_shard_size=max_shard_size,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
**model_save_attrs
|
||||
)
|
||||
llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs)
|
||||
def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, is_main_process: bool = True, state_dict: DictStrAny | None = None, save_function: t.Callable[..., None] | None = None, push_to_hub: bool = False, max_shard_size: int | str = "10GB", safe_serialization: bool = False, variant: str | None = None, **attrs: t.Any,) -> None:
|
||||
"""Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``."""
|
||||
save_function = first_not_none(save_function, default=torch.save)
|
||||
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
|
||||
# NOTE: disable safetensors for vllm
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialization = False
|
||||
if llm._quantize_method == "gptq":
|
||||
if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
|
||||
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
|
||||
elif LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model):
|
||||
raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
|
||||
elif isinstance(llm.model, _transformers.Pipeline):
|
||||
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
|
||||
else:
|
||||
# We can safely cast here since it will be the PreTrainedModel protocol.
|
||||
t.cast("_transformers.PreTrainedModel", llm.model).save_pretrained(save_directory, is_main_process=is_main_process, state_dict=state_dict, save_function=save_function, push_to_hub=push_to_hub, max_shard_size=max_shard_size, safe_serialization=safe_serialization, variant=variant, **model_save_attrs)
|
||||
llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs)
|
||||
|
||||
@@ -26,77 +26,50 @@ import openllm
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ._types import LiteralRuntime
|
||||
|
||||
from ._types import LiteralRuntime
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(
|
||||
model: str,
|
||||
model_id: str | None = None,
|
||||
quantize: t.Literal["int4", "int8", "gptq"] | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
cleanup: bool = False,
|
||||
) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info("Building BentoML for %s", model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
|
||||
yield bento
|
||||
def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False,) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info("Building BentoML for %s", model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
|
||||
yield bento
|
||||
if cleanup:
|
||||
logger.info("Deleting %s", bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any,) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
else: bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None: image_tag = str(bento_tag)
|
||||
|
||||
executable = shutil.which("docker")
|
||||
if not executable:
|
||||
raise RuntimeError("docker executable not found")
|
||||
|
||||
try:
|
||||
logger.info("Building container for %s", bento_tag)
|
||||
bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs,)
|
||||
yield image_tag
|
||||
finally:
|
||||
if cleanup:
|
||||
logger.info("Deleting %s", bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
logger.info("Deleting container %s", image_tag)
|
||||
subprocess.check_output([executable, "rmi", "-f", image_tag])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(
|
||||
bento: bentoml.Bento | str | bentoml.Tag,
|
||||
image_tag: str | None = None,
|
||||
cleanup: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
else: bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None: image_tag = str(bento_tag)
|
||||
def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True,) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
|
||||
executable = shutil.which("docker")
|
||||
if not executable:
|
||||
raise RuntimeError("docker executable not found")
|
||||
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
|
||||
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
|
||||
|
||||
try:
|
||||
logger.info("Building container for %s", bento_tag)
|
||||
bentoml.container.build(
|
||||
bento_tag,
|
||||
backend="docker",
|
||||
image_tag=(image_tag,),
|
||||
progress="plain",
|
||||
**attrs,
|
||||
)
|
||||
yield image_tag
|
||||
finally:
|
||||
if cleanup:
|
||||
logger.info("Deleting container %s", image_tag)
|
||||
subprocess.check_output([executable, "rmi", "-f", image_tag])
|
||||
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else: bento = bentoml.get(bento_tag)
|
||||
|
||||
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(
|
||||
model: str,
|
||||
model_id: str | None = None,
|
||||
implementation: LiteralRuntime = "pt",
|
||||
deployment_mode: t.Literal["container", "local"] = "local",
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True,
|
||||
) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
|
||||
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
|
||||
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
|
||||
|
||||
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else: bento = bentoml.get(bento_tag)
|
||||
|
||||
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
|
||||
|
||||
if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
if cleanup: clean_context.close()
|
||||
if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
if cleanup: clean_context.close()
|
||||
|
||||
@@ -50,74 +50,77 @@ from .lazy import LazyModule
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError:
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
_TypingGenericAlias = () # type: ignore
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
_TypingGenericAlias = () # type: ignore
|
||||
|
||||
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
|
||||
else:
|
||||
# _GenericAlias is the actual GenericAlias implementation
|
||||
_WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore
|
||||
# _GenericAlias is the actual GenericAlias implementation
|
||||
_WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload as _overload
|
||||
from typing import overload as _overload
|
||||
else:
|
||||
from typing_extensions import overload as _overload
|
||||
from typing_extensions import overload as _overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
import openllm
|
||||
|
||||
from .._types import AnyCallable
|
||||
from .._types import DictStrAny
|
||||
from .._types import LiteralRuntime
|
||||
from .._types import AnyCallable
|
||||
from .._types import DictStrAny
|
||||
from .._types import LiteralRuntime
|
||||
|
||||
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
|
||||
|
||||
def set_debug_mode(enabled: bool, level: int = 1) -> None:
|
||||
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
|
||||
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ[DEBUG_ENV_VAR] = str(enabled)
|
||||
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
|
||||
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
|
||||
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ[DEBUG_ENV_VAR] = str(enabled)
|
||||
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
|
||||
|
||||
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
|
||||
try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, _WithArgsTypes): return False
|
||||
raise
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, _WithArgsTypes): return False
|
||||
raise
|
||||
|
||||
def available_devices() -> tuple[str, ...]:
|
||||
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
|
||||
from .._strategies import NvidiaGpuResource
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
|
||||
from .._strategies import NvidiaGpuResource
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str:
|
||||
"""Generate a hash from given file's modification time.
|
||||
"""Generate a hash from given file's modification time.
|
||||
|
||||
Args:
|
||||
f: The file to generate the hash from.
|
||||
algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)
|
||||
Args:
|
||||
f: The file to generate the hash from.
|
||||
algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)
|
||||
|
||||
Returns:
|
||||
The generated hash.
|
||||
"""
|
||||
return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()
|
||||
Returns:
|
||||
The generated hash.
|
||||
"""
|
||||
return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def device_count() -> int: return len(available_devices())
|
||||
def device_count() -> int:
|
||||
return len(available_devices())
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
|
||||
"""This makes sure that we don't overwrite any existing attributes on the object."""
|
||||
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
|
||||
if not hasattr(obj, name): _setattr(name, value)
|
||||
"""This makes sure that we don't overwrite any existing attributes on the object."""
|
||||
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
|
||||
if not hasattr(obj, name): _setattr(name, value)
|
||||
|
||||
def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str: return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
|
||||
def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str:
|
||||
return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
|
||||
|
||||
# Special debug flag controled via OPENLLMDEVDEBUG
|
||||
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.getenv(DEV_DEBUG_VAR)))
|
||||
@@ -125,210 +128,201 @@ DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.gete
|
||||
MYPY = False
|
||||
SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
|
||||
|
||||
def get_debug_mode() -> bool: return DEBUG or _get_debug_mode()
|
||||
def get_quiet_mode() -> bool: return not DEBUG and _get_quiet_mode()
|
||||
def get_debug_mode() -> bool:
|
||||
return DEBUG or _get_debug_mode()
|
||||
|
||||
def get_quiet_mode() -> bool:
|
||||
return not DEBUG and _get_quiet_mode()
|
||||
|
||||
class ExceptionFilter(logging.Filter):
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
"""A filter of all exception."""
|
||||
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
|
||||
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
|
||||
super(ExceptionFilter, self).__init__(**kwargs)
|
||||
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc): return False
|
||||
return True
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
"""A filter of all exception."""
|
||||
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
|
||||
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
|
||||
super(ExceptionFilter, self).__init__(**kwargs)
|
||||
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc): return False
|
||||
return True
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return logging.INFO <= record.levelno < logging.WARNING
|
||||
|
||||
_LOGGING_CONFIG: DictStrAny = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": True,
|
||||
"filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}},
|
||||
"handlers": {
|
||||
"bentomlhandler": {
|
||||
"class": "logging.StreamHandler",
|
||||
"filters": ["excfilter", "infofilter"],
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"defaulthandler": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": logging.WARNING,
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"bentoml": {
|
||||
"handlers": ["bentomlhandler", "defaulthandler"],
|
||||
"level": logging.INFO,
|
||||
"propagate": False,
|
||||
},
|
||||
"openllm": {
|
||||
"handlers": ["bentomlhandler", "defaulthandler"],
|
||||
"level": logging.INFO,
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
"root": {"level": logging.WARNING},
|
||||
"version": 1, "disable_existing_loggers": True, "filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}}, "handlers": {
|
||||
"bentomlhandler": {"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout",}, "defaulthandler": {"class": "logging.StreamHandler", "level": logging.WARNING,},
|
||||
}, "loggers": {"bentoml": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,}, "openllm": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,},}, "root": {"level": logging.WARNING},
|
||||
}
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure logging for OpenLLM.
|
||||
"""Configure logging for OpenLLM.
|
||||
|
||||
Behaves similar to how BentoML loggers are being configured.
|
||||
"""
|
||||
if get_quiet_mode():
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.ERROR
|
||||
elif get_debug_mode() or DEBUG:
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.DEBUG
|
||||
else:
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.INFO
|
||||
Behaves similar to how BentoML loggers are being configured.
|
||||
"""
|
||||
if get_quiet_mode():
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.ERROR
|
||||
elif get_debug_mode() or DEBUG:
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.DEBUG
|
||||
else:
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.INFO
|
||||
|
||||
logging.config.dictConfig(_LOGGING_CONFIG)
|
||||
logging.config.dictConfig(_LOGGING_CONFIG)
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def in_notebook() -> bool:
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
if "IPKernelApp" not in get_ipython().config: return False
|
||||
except ImportError: return False
|
||||
except AttributeError: return False
|
||||
return True
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
if "IPKernelApp" not in get_ipython().config: return False
|
||||
except ImportError:
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
_dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup")
|
||||
|
||||
class suppress(contextlib.suppress, contextlib.ContextDecorator):
|
||||
"""A version of contextlib.suppress with decorator support.
|
||||
"""A version of contextlib.suppress with decorator support.
|
||||
|
||||
>>> @suppress(KeyError)
|
||||
... def key_error():
|
||||
... {}['']
|
||||
>>> key_error()
|
||||
"""
|
||||
>>> @suppress(KeyError)
|
||||
... def key_error():
|
||||
... {}['']
|
||||
>>> key_error()
|
||||
"""
|
||||
|
||||
def compose(*funcs: AnyCallable) -> AnyCallable:
|
||||
"""Compose any number of unary functions into a single unary function.
|
||||
"""Compose any number of unary functions into a single unary function.
|
||||
|
||||
>>> import textwrap
|
||||
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
|
||||
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
|
||||
>>> strip_and_dedent(compose.__doc__) == expected
|
||||
True
|
||||
>>> import textwrap
|
||||
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
|
||||
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
|
||||
>>> strip_and_dedent(compose.__doc__) == expected
|
||||
True
|
||||
|
||||
Compose also allows the innermost function to take arbitrary arguments.
|
||||
Compose also allows the innermost function to take arbitrary arguments.
|
||||
|
||||
>>> round_three = lambda x: round(x, ndigits=3)
|
||||
>>> f = compose(round_three, int.__truediv__)
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
"""
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
return functools.reduce(compose_two, funcs)
|
||||
>>> round_three = lambda x: round(x, ndigits=3)
|
||||
>>> f = compose(round_three, int.__truediv__)
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
"""
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
|
||||
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
|
||||
return functools.reduce(compose_two, funcs)
|
||||
|
||||
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
|
||||
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
|
||||
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
|
||||
|
||||
```python
|
||||
@apply(reversed)
|
||||
def get_numbers(start):
|
||||
"doc for get_numbers"
|
||||
return range(start, start+3)
|
||||
list(get_numbers(4))
|
||||
# [6, 5, 4]
|
||||
```
|
||||
```python
|
||||
get_numbers.__doc__
|
||||
# 'doc for get_numbers'
|
||||
```
|
||||
"""
|
||||
return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
```python
|
||||
@apply(reversed)
|
||||
def get_numbers(start):
|
||||
"doc for get_numbers"
|
||||
return range(start, start+3)
|
||||
list(get_numbers(4))
|
||||
# [6, 5, 4]
|
||||
```
|
||||
```python
|
||||
get_numbers.__doc__
|
||||
# 'doc for get_numbers'
|
||||
```
|
||||
"""
|
||||
return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
|
||||
@apply(bool)
|
||||
@suppress(FileNotFoundError)
|
||||
def _text_in_file(text: str, filename: Path) -> bool: return any(text in line for line in filename.open())
|
||||
|
||||
def _text_in_file(text: str, filename: Path) -> bool:
|
||||
return any(text in line for line in filename.open())
|
||||
|
||||
def in_docker() -> bool:
|
||||
"""Is this current environment running in docker?
|
||||
"""Is this current environment running in docker?
|
||||
|
||||
```python
|
||||
type(in_docker())
|
||||
```
|
||||
"""
|
||||
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
|
||||
```python
|
||||
type(in_docker())
|
||||
```
|
||||
"""
|
||||
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
|
||||
|
||||
T, K = t.TypeVar("T"), t.TypeVar("K")
|
||||
|
||||
def resolve_filepath(path: str, ctx: str | None = None) -> str:
|
||||
"""Resolve a file path to an absolute path, expand user and environment variables."""
|
||||
try: return resolve_user_filepath(path, ctx)
|
||||
except FileNotFoundError: return path
|
||||
"""Resolve a file path to an absolute path, expand user and environment variables."""
|
||||
try:
|
||||
return resolve_user_filepath(path, ctx)
|
||||
except FileNotFoundError:
|
||||
return path
|
||||
|
||||
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
def validate_is_path(maybe_path: str) -> bool:
|
||||
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
def generate_context(framework_name: str) -> _ModelContext:
|
||||
from .import_utils import is_flax_available
|
||||
from .import_utils import is_tf_available
|
||||
from .import_utils import is_torch_available
|
||||
from .import_utils import is_flax_available
|
||||
from .import_utils import is_tf_available
|
||||
from .import_utils import is_torch_available
|
||||
|
||||
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
|
||||
if is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
|
||||
if is_tf_available():
|
||||
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
|
||||
framework_versions["tensorflow"] = get_tf_version()
|
||||
if is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
|
||||
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
|
||||
if is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
|
||||
if is_tf_available():
|
||||
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
|
||||
framework_versions["tensorflow"] = get_tf_version()
|
||||
if is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
|
||||
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> DictStrAny:
|
||||
return {
|
||||
"runtime": llm.runtime,
|
||||
"framework": "openllm",
|
||||
"model_name": llm.config["model_name"],
|
||||
"architecture": llm.config["architecture"],
|
||||
"serialisation_format": llm._serialisation_format,
|
||||
}
|
||||
return {"runtime": llm.runtime, "framework": "openllm", "model_name": llm.config["model_name"], "architecture": llm.config["architecture"], "serialisation_format": llm._serialisation_format,}
|
||||
|
||||
_TOKENIZER_PREFIX = "_tokenizer_"
|
||||
|
||||
def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[DictStrAny, DictStrAny]:
|
||||
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
for k in tuple(attrs.keys()):
|
||||
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
for k in tuple(attrs.keys()):
|
||||
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: ...
|
||||
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]:
|
||||
...
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: ...
|
||||
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]:
|
||||
...
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: ...
|
||||
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]:
|
||||
...
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: ...
|
||||
def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]:
|
||||
...
|
||||
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM] | type[openllm.AutoTFLLM] | type[openllm.AutoFlaxLLM] | type[openllm.AutoVLLM]:
|
||||
if implementation == "tf":
|
||||
from openllm import AutoTFLLM
|
||||
return AutoTFLLM
|
||||
elif implementation == "flax":
|
||||
from openllm import AutoFlaxLLM
|
||||
return AutoFlaxLLM
|
||||
elif implementation == "pt":
|
||||
from openllm import AutoLLM
|
||||
return AutoLLM
|
||||
elif implementation == "vllm":
|
||||
from openllm import AutoVLLM
|
||||
return AutoVLLM
|
||||
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
|
||||
|
||||
if implementation == "tf":
|
||||
from openllm import AutoTFLLM
|
||||
return AutoTFLLM
|
||||
elif implementation == "flax":
|
||||
from openllm import AutoFlaxLLM
|
||||
return AutoFlaxLLM
|
||||
elif implementation == "pt":
|
||||
from openllm import AutoLLM
|
||||
return AutoLLM
|
||||
elif implementation == "vllm":
|
||||
from openllm import AutoVLLM
|
||||
return AutoVLLM
|
||||
else:
|
||||
raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
|
||||
|
||||
# NOTE: The set marks contains a set of modules name
|
||||
# that are available above and are whitelisted
|
||||
@@ -337,80 +331,52 @@ _whitelist_modules = {"pkg"}
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras: dict[str, t.Any] = {
|
||||
k: v
|
||||
for k, v in locals().items()
|
||||
if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))
|
||||
}
|
||||
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))}
|
||||
|
||||
_extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"}
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"analytics": [],
|
||||
"codegen": [],
|
||||
"dantic": [],
|
||||
"representation": ["ReprMixin"],
|
||||
"lazy": ["LazyModule"],
|
||||
"import_utils": [
|
||||
"OPTIONAL_DEPENDENCIES",
|
||||
"ENV_VARS_TRUE_VALUES",
|
||||
"DummyMetaclass",
|
||||
"EnvVarMixin",
|
||||
"requires_dependencies",
|
||||
"is_cpm_kernels_available",
|
||||
"is_einops_available",
|
||||
"is_flax_available",
|
||||
"is_tf_available",
|
||||
"is_vllm_available",
|
||||
"is_torch_available",
|
||||
"is_bitsandbytes_available",
|
||||
"is_peft_available",
|
||||
"is_datasets_available",
|
||||
"is_transformers_supports_kbit",
|
||||
"is_transformers_supports_agent",
|
||||
"is_jupyter_available",
|
||||
"is_jupytext_available",
|
||||
"is_notebook_available",
|
||||
"is_triton_available",
|
||||
"is_autogptq_available",
|
||||
"require_backends",
|
||||
"analytics": [], "codegen": [], "dantic": [], "representation": ["ReprMixin"], "lazy": ["LazyModule"], "import_utils": [
|
||||
"OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "EnvVarMixin", "requires_dependencies", "is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_vllm_available", "is_torch_available", "is_bitsandbytes_available", "is_peft_available", "is_datasets_available", "is_transformers_supports_kbit",
|
||||
"is_transformers_supports_agent", "is_jupyter_available", "is_jupytext_available", "is_notebook_available", "is_triton_available", "is_autogptq_available", "require_backends",
|
||||
],
|
||||
}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# NOTE: The following exports useful utils from bentoml
|
||||
from . import LazyLoader as LazyLoader
|
||||
from . import LazyType as LazyType
|
||||
from . import analytics as analytics
|
||||
from . import bentoml_cattr as bentoml_cattr
|
||||
from . import codegen as codegen
|
||||
from . import configure_logging as configure_logging
|
||||
from . import dantic as dantic
|
||||
from . import first_not_none as first_not_none
|
||||
from . import reserve_free_port as reserve_free_port
|
||||
from . import set_quiet_mode as set_quiet_mode
|
||||
from . import validate_is_path as validate_is_path
|
||||
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
|
||||
from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES
|
||||
from .import_utils import DummyMetaclass as DummyMetaclass
|
||||
from .import_utils import EnvVarMixin as EnvVarMixin
|
||||
from .import_utils import is_autogptq_available as is_autogptq_available
|
||||
from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available
|
||||
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
|
||||
from .import_utils import is_datasets_available as is_datasets_available
|
||||
from .import_utils import is_einops_available as is_einops_available
|
||||
from .import_utils import is_flax_available as is_flax_available
|
||||
from .import_utils import is_jupyter_available as is_jupyter_available
|
||||
from .import_utils import is_jupytext_available as is_jupytext_available
|
||||
from .import_utils import is_notebook_available as is_notebook_available
|
||||
from .import_utils import is_peft_available as is_peft_available
|
||||
from .import_utils import is_tf_available as is_tf_available
|
||||
from .import_utils import is_torch_available as is_torch_available
|
||||
from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent
|
||||
from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit
|
||||
from .import_utils import is_triton_available as is_triton_available
|
||||
from .import_utils import is_vllm_available as is_vllm_available
|
||||
from .import_utils import require_backends as require_backends
|
||||
from .import_utils import requires_dependencies as requires_dependencies
|
||||
from .representation import ReprMixin as ReprMixin
|
||||
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras)
|
||||
# NOTE: The following exports useful utils from bentoml
|
||||
from . import LazyLoader as LazyLoader
|
||||
from . import LazyType as LazyType
|
||||
from . import analytics as analytics
|
||||
from . import bentoml_cattr as bentoml_cattr
|
||||
from . import codegen as codegen
|
||||
from . import configure_logging as configure_logging
|
||||
from . import dantic as dantic
|
||||
from . import first_not_none as first_not_none
|
||||
from . import reserve_free_port as reserve_free_port
|
||||
from . import set_quiet_mode as set_quiet_mode
|
||||
from . import validate_is_path as validate_is_path
|
||||
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
|
||||
from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES
|
||||
from .import_utils import DummyMetaclass as DummyMetaclass
|
||||
from .import_utils import EnvVarMixin as EnvVarMixin
|
||||
from .import_utils import is_autogptq_available as is_autogptq_available
|
||||
from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available
|
||||
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
|
||||
from .import_utils import is_datasets_available as is_datasets_available
|
||||
from .import_utils import is_einops_available as is_einops_available
|
||||
from .import_utils import is_flax_available as is_flax_available
|
||||
from .import_utils import is_jupyter_available as is_jupyter_available
|
||||
from .import_utils import is_jupytext_available as is_jupytext_available
|
||||
from .import_utils import is_notebook_available as is_notebook_available
|
||||
from .import_utils import is_peft_available as is_peft_available
|
||||
from .import_utils import is_tf_available as is_tf_available
|
||||
from .import_utils import is_torch_available as is_torch_available
|
||||
from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent
|
||||
from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit
|
||||
from .import_utils import is_triton_available as is_triton_available
|
||||
from .import_utils import is_vllm_available as is_vllm_available
|
||||
from .import_utils import require_backends as require_backends
|
||||
from .import_utils import requires_dependencies as requires_dependencies
|
||||
from .representation import ReprMixin as ReprMixin
|
||||
else:
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__, extra_objects=_extras)
|
||||
|
||||
@@ -30,12 +30,11 @@ import openllm
|
||||
from bentoml._internal.utils import analytics as _internal_analytics
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .._types import P
|
||||
from .._types import T
|
||||
from .._types import P
|
||||
from .._types import T
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
||||
|
||||
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
|
||||
@@ -43,87 +42,78 @@ OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
|
||||
|
||||
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def do_not_track() -> bool:
|
||||
return DO_NOT_TRACK in ENV_VARS_TRUE_VALUES
|
||||
|
||||
return DO_NOT_TRACK in ENV_VARS_TRUE_VALUES
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _usage_event_debugging() -> bool:
|
||||
# For BentoML developers only - debug and print event payload if turned on
|
||||
return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
|
||||
|
||||
# For BentoML developers only - debug and print event payload if turned on
|
||||
return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
|
||||
|
||||
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as err:
|
||||
if _usage_event_debugging():
|
||||
if openllm.utils.get_debug_mode():
|
||||
logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
|
||||
else:
|
||||
logger.info("Tracking Error: %s", err)
|
||||
else:
|
||||
logger.debug("Tracking Error: %s", err)
|
||||
|
||||
return wrapper
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as err:
|
||||
if _usage_event_debugging():
|
||||
if openllm.utils.get_debug_mode():
|
||||
logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
|
||||
else:
|
||||
logger.info("Tracking Error: %s", err)
|
||||
else:
|
||||
logger.debug("Tracking Error: %s", err)
|
||||
|
||||
return wrapper
|
||||
|
||||
@silent
|
||||
def track(event_properties: attr.AttrsInstance) -> None:
|
||||
if do_not_track():
|
||||
return
|
||||
_internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
|
||||
|
||||
if do_not_track():
|
||||
return
|
||||
_internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_bentoml_tracking() -> t.Generator[None, None, None]:
|
||||
original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
|
||||
try:
|
||||
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
|
||||
yield
|
||||
finally:
|
||||
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
|
||||
|
||||
original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
|
||||
try:
|
||||
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
|
||||
yield
|
||||
finally:
|
||||
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
|
||||
|
||||
class EventMeta:
|
||||
@property
|
||||
def event_name(self) -> str:
|
||||
# camel case to snake case
|
||||
event_name = re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
|
||||
# remove "_event" suffix
|
||||
suffix_to_remove = "_event"
|
||||
if event_name.endswith(suffix_to_remove):
|
||||
event_name = event_name[: -len(suffix_to_remove)]
|
||||
return event_name
|
||||
|
||||
@property
|
||||
def event_name(self) -> str:
|
||||
# camel case to snake case
|
||||
event_name = re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
|
||||
# remove "_event" suffix
|
||||
suffix_to_remove = "_event"
|
||||
if event_name.endswith(suffix_to_remove):
|
||||
event_name = event_name[:-len(suffix_to_remove)]
|
||||
return event_name
|
||||
|
||||
@attr.define
|
||||
class OpenllmCliEvent(EventMeta):
|
||||
cmd_group: str
|
||||
cmd_name: str
|
||||
openllm_version: str = importlib.metadata.version("openllm")
|
||||
|
||||
# NOTE: reserved for the do_not_track logics
|
||||
duration_in_ms: t.Any = attr.field(default=None)
|
||||
error_type: str = attr.field(default=None)
|
||||
return_code: int = attr.field(default=None)
|
||||
cmd_group: str
|
||||
cmd_name: str
|
||||
openllm_version: str = importlib.metadata.version("openllm")
|
||||
|
||||
# NOTE: reserved for the do_not_track logics
|
||||
duration_in_ms: t.Any = attr.field(default=None)
|
||||
error_type: str = attr.field(default=None)
|
||||
return_code: int = attr.field(default=None)
|
||||
|
||||
@attr.define
|
||||
class StartInitEvent(EventMeta):
|
||||
model_name: str
|
||||
llm_config: t.Dict[str, t.Any] = attr.field(default=None)
|
||||
|
||||
@staticmethod
|
||||
def handler(llm_config: openllm.LLMConfig) -> StartInitEvent:
|
||||
return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
|
||||
model_name: str
|
||||
llm_config: t.Dict[str, t.Any] = attr.field(default=None)
|
||||
|
||||
@staticmethod
|
||||
def handler(llm_config: openllm.LLMConfig) -> StartInitEvent:
|
||||
return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
|
||||
|
||||
def track_start_init(llm_config: openllm.LLMConfig) -> None:
|
||||
if do_not_track():
|
||||
return
|
||||
track(StartInitEvent.handler(llm_config))
|
||||
if do_not_track():
|
||||
return
|
||||
track(StartInitEvent.handler(llm_config))
|
||||
|
||||
@@ -26,19 +26,19 @@ from pathlib import Path
|
||||
import orjson
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
from fs.base import FS
|
||||
|
||||
import openllm
|
||||
import openllm
|
||||
|
||||
from .._types import AnyCallable
|
||||
from .._types import DictStrAny
|
||||
from .._types import ListStr
|
||||
from .._types import AnyCallable
|
||||
from .._types import DictStrAny
|
||||
from .._types import ListStr
|
||||
|
||||
PartialAny = functools.partial[t.Any]
|
||||
PartialAny = functools.partial[t.Any]
|
||||
else:
|
||||
DictStrAny = dict
|
||||
ListStr = list
|
||||
PartialAny = functools.partial
|
||||
DictStrAny = dict
|
||||
ListStr = list
|
||||
PartialAny = functools.partial
|
||||
|
||||
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
|
||||
|
||||
@@ -47,274 +47,206 @@ logger = logging.getLogger(__name__)
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
|
||||
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: t.LiteralString = "__model_name__"
|
||||
model_keyword: t.LiteralString = "__model_name__"
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
|
||||
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
return super().vformat(format_string, (), {self.model_keyword: self.model_name})
|
||||
|
||||
def can_format(self, value: str) -> bool:
|
||||
try:
|
||||
self.parse(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
return super().vformat(format_string, (), {self.model_keyword: self.model_name})
|
||||
|
||||
def can_format(self, value: str) -> bool:
|
||||
try:
|
||||
self.parse(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: t.LiteralString = "__model_id__"
|
||||
|
||||
model_keyword: t.LiteralString = "__model_id__"
|
||||
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: t.LiteralString = "__model_adapter_map__"
|
||||
|
||||
model_keyword: t.LiteralString = "__model_adapter_map__"
|
||||
|
||||
_service_file = Path(__file__).parent.parent / "_service.py"
|
||||
|
||||
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from . import DEBUG
|
||||
from . import DEBUG
|
||||
|
||||
model_name = llm.config["model_name"]
|
||||
model_name = llm.config["model_name"]
|
||||
|
||||
logger.debug("Generating service for %s", model_name)
|
||||
logger.debug("Generating service for %s", model_name)
|
||||
|
||||
with open(_service_file.__fspath__(), "r") as f:
|
||||
src_contents = f.readlines()
|
||||
with open(_service_file.__fspath__(), "r") as f:
|
||||
src_contents = f.readlines()
|
||||
|
||||
# modify with model name
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it:
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + "\n"
|
||||
)
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[
|
||||
: -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)
|
||||
]
|
||||
+ "\n"
|
||||
)
|
||||
# modify with model name
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it:
|
||||
src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n")
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n")
|
||||
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
|
||||
|
||||
if DEBUG:
|
||||
logger.info("Generated script:\n%s", script)
|
||||
|
||||
llm_fs.writetext(llm.config["service_name"], script)
|
||||
if DEBUG:
|
||||
logger.info("Generated script:\n%s", script)
|
||||
|
||||
llm_fs.writetext(llm.config["service_name"], script)
|
||||
|
||||
# NOTE: The following ins extracted from attrs internal APIs
|
||||
|
||||
# sentinel object for unequivocal object() getattr
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
|
||||
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
|
||||
attr = getattr(cls, attrib_name, _sentinel)
|
||||
if attr is _sentinel:
|
||||
return False
|
||||
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
|
||||
attr = getattr(cls, attrib_name, _sentinel)
|
||||
if attr is _sentinel:
|
||||
return False
|
||||
|
||||
for base_cls in cls.__mro__[1:]:
|
||||
a = getattr(base_cls, attrib_name, None)
|
||||
if attr is a:
|
||||
return False
|
||||
|
||||
return True
|
||||
for base_cls in cls.__mro__[1:]:
|
||||
a = getattr(base_cls, attrib_name, None)
|
||||
if attr is a:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_annotations(cls: type[t.Any]) -> DictStrAny:
|
||||
"""Get annotations for *cls*."""
|
||||
if has_own_attribute(cls, "__annotations__"):
|
||||
return cls.__annotations__
|
||||
"""Get annotations for *cls*."""
|
||||
if has_own_attribute(cls, "__annotations__"):
|
||||
return cls.__annotations__
|
||||
|
||||
return DictStrAny()
|
||||
|
||||
|
||||
_classvar_prefixes = (
|
||||
"typing.ClassVar",
|
||||
"t.ClassVar",
|
||||
"ClassVar",
|
||||
"typing_extensions.ClassVar",
|
||||
)
|
||||
return DictStrAny()
|
||||
|
||||
_classvar_prefixes = ("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar",)
|
||||
|
||||
def is_class_var(annot: str | t.Any) -> bool:
|
||||
"""Check whether *annot* is a typing.ClassVar.
|
||||
"""Check whether *annot* is a typing.ClassVar.
|
||||
|
||||
The string comparison hack is used to avoid evaluating all string
|
||||
annotations which would put attrs-based classes at a performance
|
||||
disadvantage compared to plain old classes.
|
||||
"""
|
||||
annot = str(annot)
|
||||
The string comparison hack is used to avoid evaluating all string
|
||||
annotations which would put attrs-based classes at a performance
|
||||
disadvantage compared to plain old classes.
|
||||
"""
|
||||
annot = str(annot)
|
||||
|
||||
# Annotation can be quoted.
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
|
||||
annot = annot[1:-1]
|
||||
|
||||
return annot.startswith(_classvar_prefixes)
|
||||
# Annotation can be quoted.
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
|
||||
annot = annot[1:-1]
|
||||
|
||||
return annot.startswith(_classvar_prefixes)
|
||||
|
||||
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
|
||||
"""Add __module__ and __qualname__ to a *method* if possible."""
|
||||
try: method_or_cls.__module__ = cls.__module__
|
||||
except AttributeError: pass
|
||||
try: method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
|
||||
except AttributeError: pass
|
||||
try: method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
|
||||
except AttributeError: pass
|
||||
return method_or_cls
|
||||
|
||||
"""Add __module__ and __qualname__ to a *method* if possible."""
|
||||
try:
|
||||
method_or_cls.__module__ = cls.__module__
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
|
||||
except AttributeError:
|
||||
pass
|
||||
return method_or_cls
|
||||
|
||||
# Exec the script with the given global (globs) and local (locs) variables.
|
||||
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
|
||||
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None:
|
||||
eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
|
||||
|
||||
# ported from attrs
|
||||
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
|
||||
"""Create the method with the script given and return the method object."""
|
||||
locs: DictStrAny = {}
|
||||
|
||||
# In order of debuggers like PDB being able to step through the code,
|
||||
# we add a fake linecache entry.
|
||||
count = 1
|
||||
base_filename = filename
|
||||
while True:
|
||||
linecache_tuple = (len(script), None, script.splitlines(True), filename)
|
||||
old_val = linecache.cache.setdefault(filename, linecache_tuple)
|
||||
if old_val == linecache_tuple: break
|
||||
else:
|
||||
filename = f"{base_filename[:-1]}-{count}>"
|
||||
count += 1
|
||||
_compile_and_eval(script, globs, locs, filename)
|
||||
return locs[name]
|
||||
"""Create the method with the script given and return the method object."""
|
||||
locs: DictStrAny = {}
|
||||
|
||||
# In order of debuggers like PDB being able to step through the code,
|
||||
# we add a fake linecache entry.
|
||||
count = 1
|
||||
base_filename = filename
|
||||
while True:
|
||||
linecache_tuple = (len(script), None, script.splitlines(True), filename)
|
||||
old_val = linecache.cache.setdefault(filename, linecache_tuple)
|
||||
if old_val == linecache_tuple: break
|
||||
else:
|
||||
filename = f"{base_filename[:-1]}-{count}>"
|
||||
count += 1
|
||||
_compile_and_eval(script, globs, locs, filename)
|
||||
return locs[name]
|
||||
|
||||
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
|
||||
"""Create a tuple subclass to hold class attributes.
|
||||
"""Create a tuple subclass to hold class attributes.
|
||||
|
||||
The subclass is a bare tuple with properties for names.
|
||||
The subclass is a bare tuple with properties for names.
|
||||
|
||||
class MyClassAttributes(tuple):
|
||||
__slots__ = ()
|
||||
x = property(itemgetter(0))
|
||||
"""
|
||||
from . import SHOW_CODEGEN
|
||||
class MyClassAttributes(tuple):
|
||||
__slots__ = ()
|
||||
x = property(itemgetter(0))
|
||||
"""
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
attr_class_name = f"{cls_name}Attributes"
|
||||
attr_class_template = [
|
||||
f"class {attr_class_name}(tuple):",
|
||||
" __slots__ = ()",
|
||||
]
|
||||
if attr_names:
|
||||
for i, attr_name in enumerate(attr_names): attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
|
||||
else: attr_class_template.append(" pass")
|
||||
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
|
||||
if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
|
||||
_compile_and_eval("\n".join(attr_class_template), globs)
|
||||
return globs[attr_class_name]
|
||||
attr_class_name = f"{cls_name}Attributes"
|
||||
attr_class_template = [f"class {attr_class_name}(tuple):", " __slots__ = ()",]
|
||||
if attr_names:
|
||||
for i, attr_name in enumerate(attr_names):
|
||||
attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
|
||||
else:
|
||||
attr_class_template.append(" pass")
|
||||
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
|
||||
if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
|
||||
_compile_and_eval("\n".join(attr_class_template), globs)
|
||||
return globs[attr_class_name]
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
|
||||
return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None,) -> AnyCallable:
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
|
||||
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
|
||||
if annotations: meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
|
||||
|
||||
def generate_function(
|
||||
typ: type[t.Any],
|
||||
func_name: str,
|
||||
lines: list[str] | None,
|
||||
args: tuple[str, ...] | None,
|
||||
globs: dict[str, t.Any],
|
||||
annotations: dict[str, t.Any] | None = None,
|
||||
) -> AnyCallable:
|
||||
from . import SHOW_CODEGEN
|
||||
return meth
|
||||
|
||||
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
|
||||
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
|
||||
if annotations: meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
|
||||
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: t.LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
|
||||
from . import dantic
|
||||
from . import field_env_key
|
||||
|
||||
return meth
|
||||
def identity(_: str, x_value: t.Any) -> t.Any:
|
||||
return x_value
|
||||
|
||||
default_callback = identity if default_callback is None else default_callback
|
||||
|
||||
def make_env_transformer(
|
||||
cls: type[openllm.LLMConfig],
|
||||
model_name: str,
|
||||
suffix: t.LiteralString | None = None,
|
||||
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
|
||||
globs: DictStrAny | None = None,
|
||||
) -> AnyCallable:
|
||||
from . import dantic
|
||||
from . import field_env_key
|
||||
globs = {} if globs is None else globs
|
||||
globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
|
||||
|
||||
def identity(_: str, x_value: t.Any) -> t.Any:
|
||||
return x_value
|
||||
lines: ListStr = [
|
||||
"__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },",
|
||||
" )", " for f in fields", "]",
|
||||
]
|
||||
fields_ann = "list[attr.Attribute[t.Any]]"
|
||||
|
||||
default_callback = identity if default_callback is None else default_callback
|
||||
|
||||
globs = {} if globs is None else globs
|
||||
globs.update(
|
||||
{
|
||||
"__populate_env": dantic.env_converter,
|
||||
"__default_callback": default_callback,
|
||||
"__field_env": field_env_key,
|
||||
"__suffix": suffix or "",
|
||||
"__model_name": model_name,
|
||||
}
|
||||
)
|
||||
|
||||
lines: ListStr = [
|
||||
"__env = lambda field_name: __field_env(__model_name, field_name, __suffix)",
|
||||
"return [",
|
||||
" f.evolve(",
|
||||
" default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),",
|
||||
" metadata={",
|
||||
" 'env': f.metadata.get('env', __env(f.name)),",
|
||||
" 'description': f.metadata.get('description', '(not provided)'),",
|
||||
" },",
|
||||
" )",
|
||||
" for f in fields",
|
||||
"]",
|
||||
]
|
||||
fields_ann = "list[attr.Attribute[t.Any]]"
|
||||
|
||||
return generate_function(
|
||||
cls,
|
||||
"__auto_env",
|
||||
lines,
|
||||
args=("_", "fields"),
|
||||
globs=globs,
|
||||
annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},
|
||||
)
|
||||
return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},)
|
||||
|
||||
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
"""Enhance sdk with nice repr that plays well with your brain."""
|
||||
from .representation import ReprMixin
|
||||
"""Enhance sdk with nice repr that plays well with your brain."""
|
||||
from .representation import ReprMixin
|
||||
|
||||
if name is None: name = func.__name__.strip("_")
|
||||
_signatures = inspect.signature(func).parameters
|
||||
def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
|
||||
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
|
||||
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
|
||||
else: doc = func.__doc__
|
||||
return t.cast(_T, functools.update_wrapper(
|
||||
types.new_class(
|
||||
name,
|
||||
(PartialAny, ReprMixin),
|
||||
exec_body=lambda ns: ns.update(
|
||||
{
|
||||
"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]),
|
||||
"__repr_args__": _repr_args,
|
||||
"__repr__": _repr,
|
||||
"__doc__": inspect.cleandoc(doc),
|
||||
"__module__": "openllm",
|
||||
}
|
||||
),
|
||||
)(func, **attrs),
|
||||
func,
|
||||
))
|
||||
if name is None: name = func.__name__.strip("_")
|
||||
_signatures = inspect.signature(func).parameters
|
||||
|
||||
def _repr(self: ReprMixin) -> str:
|
||||
return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
|
||||
|
||||
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
|
||||
return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
|
||||
|
||||
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
|
||||
else: doc = func.__doc__
|
||||
return t.cast(_T, functools.update_wrapper(types.new_class(name, (PartialAny, ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
|
||||
|
||||
@@ -31,478 +31,432 @@ from click import shell_completion as sc
|
||||
from click import types as click_types
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import _ValidatorType
|
||||
from attr import _ValidatorType
|
||||
|
||||
from .._types import ListAny
|
||||
from .._types import ListAny
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
|
||||
|
||||
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: type[t.Any] | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
underscored = inflection.underscore(name)
|
||||
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
suffix_sampling: bool = False,
|
||||
) -> t.Callable[[FC], FC]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
underscored = inflection.underscore(name)
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
|
||||
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
|
||||
|
||||
full_option_name = f"--{dasherized}"
|
||||
if field.type is bool: full_option_name += f"/--no-{dasherized}"
|
||||
if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
|
||||
elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
|
||||
else: identifier = f"{model_name}_{underscored}"
|
||||
|
||||
return cog.optgroup.option(
|
||||
identifier,
|
||||
full_option_name,
|
||||
type=parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=allows_multiple(typ) if typ else False,
|
||||
help=field.metadata.get("description", "(No description provided)"),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
full_option_name = f"--{dasherized}"
|
||||
if field.type is bool: full_option_name += f"/--no-{dasherized}"
|
||||
if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
|
||||
elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
|
||||
else: identifier = f"{model_name}_{underscored}"
|
||||
|
||||
return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,)
|
||||
|
||||
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
if env is not None:
|
||||
value = os.environ.get(env, value)
|
||||
if value is not None and isinstance(value, str):
|
||||
try:
|
||||
return orjson.loads(value.lower())
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
if env is not None:
|
||||
value = os.environ.get(env, value)
|
||||
if value is not None and isinstance(value, str):
|
||||
try:
|
||||
return orjson.loads(value.lower())
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
|
||||
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any:
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
|
||||
def Field(
|
||||
default: t.Any = None,
|
||||
*,
|
||||
ge: int | float | None = None,
|
||||
le: int | float | None = None,
|
||||
validator: _ValidatorType[_T] | None = None,
|
||||
description: str | None = None,
|
||||
env: str | None = None,
|
||||
auto_default: bool = False,
|
||||
use_default_converter: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> t.Any:
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
piped into first, then all of the other validator will be run afterwards.
|
||||
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
piped into first, then all of the other validator will be run afterwards.
|
||||
Args:
|
||||
default: The default value for ``dantic.Field``. Defaults to ``None``.
|
||||
ge: Greater than or equal to. Defaults to None.
|
||||
le: Less than or equal to. Defaults to None.
|
||||
validator: Optional attrs-compatible validators type. Default to None
|
||||
description: the documentation for the field. Defaults to None.
|
||||
env: the environment variable to read from. Defaults to None.
|
||||
auto_default: a bool indicating whether to use the default value as the environment.
|
||||
Defaults to False. If set to True, the behaviour of this Field will also depends
|
||||
on kw_only. If kw_only=True, the this field will become 'Required' and the default
|
||||
value is omitted. If kw_only=False, then the default value will be used as before.
|
||||
use_default_converter: a bool indicating whether to use the default converter. Defaults
|
||||
to True. If set to False, then the default converter will not be used.
|
||||
The default converter converts a given value from the environment variable
|
||||
for this given Field.
|
||||
**attrs: The rest of the arguments are passed to attr.field
|
||||
"""
|
||||
metadata = attrs.pop("metadata", {})
|
||||
if description is None:
|
||||
description = "(No description provided)"
|
||||
metadata["description"] = description
|
||||
if env is not None:
|
||||
metadata["env"] = env
|
||||
piped: list[_ValidatorType[t.Any]] = []
|
||||
|
||||
Args:
|
||||
default: The default value for ``dantic.Field``. Defaults to ``None``.
|
||||
ge: Greater than or equal to. Defaults to None.
|
||||
le: Less than or equal to. Defaults to None.
|
||||
validator: Optional attrs-compatible validators type. Default to None
|
||||
description: the documentation for the field. Defaults to None.
|
||||
env: the environment variable to read from. Defaults to None.
|
||||
auto_default: a bool indicating whether to use the default value as the environment.
|
||||
Defaults to False. If set to True, the behaviour of this Field will also depends
|
||||
on kw_only. If kw_only=True, the this field will become 'Required' and the default
|
||||
value is omitted. If kw_only=False, then the default value will be used as before.
|
||||
use_default_converter: a bool indicating whether to use the default converter. Defaults
|
||||
to True. If set to False, then the default converter will not be used.
|
||||
The default converter converts a given value from the environment variable
|
||||
for this given Field.
|
||||
**attrs: The rest of the arguments are passed to attr.field
|
||||
"""
|
||||
metadata = attrs.pop("metadata", {})
|
||||
if description is None:
|
||||
description = "(No description provided)"
|
||||
metadata["description"] = description
|
||||
if env is not None:
|
||||
metadata["env"] = env
|
||||
piped: list[_ValidatorType[t.Any]] = []
|
||||
converter = attrs.pop("converter", None)
|
||||
if use_default_converter:
|
||||
converter = functools.partial(env_converter, env=env)
|
||||
|
||||
converter = attrs.pop("converter", None)
|
||||
if use_default_converter:
|
||||
converter = functools.partial(env_converter, env=env)
|
||||
if ge is not None:
|
||||
piped.append(attr.validators.ge(ge))
|
||||
if le is not None:
|
||||
piped.append(attr.validators.le(le))
|
||||
if validator is not None:
|
||||
piped.append(validator)
|
||||
|
||||
if ge is not None:
|
||||
piped.append(attr.validators.ge(ge))
|
||||
if le is not None:
|
||||
piped.append(attr.validators.le(le))
|
||||
if validator is not None:
|
||||
piped.append(validator)
|
||||
if len(piped) == 0:
|
||||
_validator = None
|
||||
elif len(piped) == 1:
|
||||
_validator = piped[0]
|
||||
else:
|
||||
_validator = attr.validators.and_(*piped)
|
||||
|
||||
if len(piped) == 0:
|
||||
_validator = None
|
||||
elif len(piped) == 1:
|
||||
_validator = piped[0]
|
||||
else:
|
||||
_validator = attr.validators.and_(*piped)
|
||||
factory = attrs.pop("factory", None)
|
||||
if factory is not None and default is not None:
|
||||
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
|
||||
# NOTE: the behaviour of this is we will respect factory over the default
|
||||
if factory is not None:
|
||||
attrs["factory"] = factory
|
||||
else:
|
||||
attrs["default"] = default
|
||||
|
||||
factory = attrs.pop("factory", None)
|
||||
if factory is not None and default is not None:
|
||||
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
|
||||
# NOTE: the behaviour of this is we will respect factory over the default
|
||||
if factory is not None:
|
||||
attrs["factory"] = factory
|
||||
else:
|
||||
attrs["default"] = default
|
||||
|
||||
kw_only = attrs.pop("kw_only", False)
|
||||
if auto_default and kw_only:
|
||||
attrs.pop("default")
|
||||
|
||||
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
kw_only = attrs.pop("kw_only", False)
|
||||
if auto_default and kw_only:
|
||||
attrs.pop("default")
|
||||
|
||||
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
|
||||
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]:
|
||||
"""Transforms the pydantic field's type into a click-compatible type.
|
||||
"""Transforms the pydantic field's type into a click-compatible type.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
ParamType: click type equivalent
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
|
||||
if t.get_origin(field_type) is t.Union:
|
||||
raise NotImplementedError("Unions are not supported")
|
||||
# enumeration strings or other Enum derivatives
|
||||
if lenient_issubclass(field_type, Enum):
|
||||
return EnumChoice(enum=field_type, case_sensitive=True)
|
||||
# literals are enum-like with way less functionality
|
||||
if is_literal(field_type):
|
||||
return LiteralChoice(value=field_type, case_sensitive=True)
|
||||
# modules, classes, functions
|
||||
if is_typing(field_type):
|
||||
return ModuleType()
|
||||
# entire dictionaries:
|
||||
# using a Dict, convert in advance
|
||||
if is_mapping(field_type):
|
||||
return JsonType()
|
||||
# list, List[p], Tuple[p], Set[p] and so on
|
||||
if is_container(field_type):
|
||||
return parse_container_args(field_type)
|
||||
# bytes are not natively supported by click
|
||||
if lenient_issubclass(field_type, bytes):
|
||||
return BytesType()
|
||||
# return the current type: it should be a primitive
|
||||
return field_type
|
||||
Returns:
|
||||
ParamType: click type equivalent
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
|
||||
if t.get_origin(field_type) is t.Union:
|
||||
raise NotImplementedError("Unions are not supported")
|
||||
# enumeration strings or other Enum derivatives
|
||||
if lenient_issubclass(field_type, Enum):
|
||||
return EnumChoice(enum=field_type, case_sensitive=True)
|
||||
# literals are enum-like with way less functionality
|
||||
if is_literal(field_type):
|
||||
return LiteralChoice(value=field_type, case_sensitive=True)
|
||||
# modules, classes, functions
|
||||
if is_typing(field_type):
|
||||
return ModuleType()
|
||||
# entire dictionaries:
|
||||
# using a Dict, convert in advance
|
||||
if is_mapping(field_type):
|
||||
return JsonType()
|
||||
# list, List[p], Tuple[p], Set[p] and so on
|
||||
if is_container(field_type):
|
||||
return parse_container_args(field_type)
|
||||
# bytes are not natively supported by click
|
||||
if lenient_issubclass(field_type, bytes):
|
||||
return BytesType()
|
||||
# return the current type: it should be a primitive
|
||||
return field_type
|
||||
|
||||
def is_typing(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a module-like type.
|
||||
"""Checks whether the current type is a module-like type.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
bool: true if the type is itself a type
|
||||
"""
|
||||
raw = t.get_origin(field_type)
|
||||
if raw is None:
|
||||
return False
|
||||
if raw is type or raw is t.Type:
|
||||
return True
|
||||
Returns:
|
||||
bool: true if the type is itself a type
|
||||
"""
|
||||
raw = t.get_origin(field_type)
|
||||
if raw is None:
|
||||
return False
|
||||
|
||||
if raw is type or raw is t.Type:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_literal(field_type: type) -> bool:
|
||||
"""Checks whether the given field type is a Literal type or not.
|
||||
"""Checks whether the given field type is a Literal type or not.
|
||||
|
||||
Literals are weird: isinstance and subclass do not work, so you compare
|
||||
the origin with the Literal declaration itself.
|
||||
Literals are weird: isinstance and subclass do not work, so you compare
|
||||
the origin with the Literal declaration itself.
|
||||
|
||||
Args:
|
||||
field_type: current pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true if Literal type, false otherwise
|
||||
"""
|
||||
origin = t.get_origin(field_type)
|
||||
return origin is not None and origin is t.Literal
|
||||
Args:
|
||||
field_type: current pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true if Literal type, false otherwise
|
||||
"""
|
||||
origin = t.get_origin(field_type)
|
||||
return origin is not None and origin is t.Literal
|
||||
|
||||
class ModuleType(ParamType):
|
||||
name = "module"
|
||||
name = "module"
|
||||
|
||||
def _import_object(self, value: str) -> t.Any:
|
||||
module_name, class_name = value.rsplit(".", maxsplit=1)
|
||||
if not all(s.isidentifier() for s in module_name.split(".")):
|
||||
raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not class_name.isidentifier():
|
||||
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
def _import_object(self, value: str) -> t.Any:
|
||||
module_name, class_name = value.rsplit(".", maxsplit=1)
|
||||
if not all(s.isidentifier() for s in module_name.split(".")):
|
||||
raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not class_name.isidentifier():
|
||||
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
if class_name:
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
|
||||
|
||||
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
return self._import_object(value)
|
||||
return value
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
|
||||
module = importlib.import_module(module_name)
|
||||
if class_name:
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
|
||||
|
||||
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
return self._import_object(value)
|
||||
return value
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
|
||||
|
||||
class EnumChoice(click.Choice):
|
||||
name = "enum"
|
||||
name = "enum"
|
||||
|
||||
def __init__(self, enum: Enum, case_sensitive: bool = False):
|
||||
"""Enum type support for click that extends ``click.Choice``.
|
||||
def __init__(self, enum: Enum, case_sensitive: bool = False):
|
||||
"""Enum type support for click that extends ``click.Choice``.
|
||||
|
||||
Args:
|
||||
enum: Given enum
|
||||
case_sensitive: Whether this choice should be case case_sensitive.
|
||||
"""
|
||||
self.mapping = enum
|
||||
self.internal_type = type(enum)
|
||||
choices: ListAny = [e.name for e in enum.__class__]
|
||||
super().__init__(choices, case_sensitive)
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
|
||||
if isinstance(value, self.internal_type):
|
||||
return value
|
||||
result = super().convert(value, param, ctx)
|
||||
if isinstance(result, str):
|
||||
result = self.internal_type[result]
|
||||
return result
|
||||
Args:
|
||||
enum: Given enum
|
||||
case_sensitive: Whether this choice should be case case_sensitive.
|
||||
"""
|
||||
self.mapping = enum
|
||||
self.internal_type = type(enum)
|
||||
choices: ListAny = [e.name for e in enum.__class__]
|
||||
super().__init__(choices, case_sensitive)
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
|
||||
if isinstance(value, self.internal_type):
|
||||
return value
|
||||
result = super().convert(value, param, ctx)
|
||||
if isinstance(result, str):
|
||||
result = self.internal_type[result]
|
||||
return result
|
||||
|
||||
class LiteralChoice(EnumChoice):
|
||||
name = "literal"
|
||||
|
||||
def __init__(self, value: t.Any, case_sensitive: bool = False):
|
||||
"""Literal support for click."""
|
||||
# expect every literal value to belong to the same primitive type
|
||||
values = list(value.__args__)
|
||||
item_type = type(values[0])
|
||||
if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
|
||||
_mapping = {str(v): v for v in values}
|
||||
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
|
||||
self.internal_type = item_type
|
||||
name = "literal"
|
||||
|
||||
def __init__(self, value: t.Any, case_sensitive: bool = False):
|
||||
"""Literal support for click."""
|
||||
# expect every literal value to belong to the same primitive type
|
||||
values = list(value.__args__)
|
||||
item_type = type(values[0])
|
||||
if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
|
||||
_mapping = {str(v): v for v in values}
|
||||
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
|
||||
self.internal_type = item_type
|
||||
|
||||
def allows_multiple(field_type: type[t.Any]) -> bool:
|
||||
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
|
||||
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
|
||||
|
||||
For containers, it exploits click's support for lists and such to use the same option multiple times
|
||||
to create a complex object: `python run.py --subsets train --subsets test`
|
||||
# becomes `subsets: ["train", "test"]`.
|
||||
For containers, it exploits click's support for lists and such to use the same option multiple times
|
||||
to create a complex object: `python run.py --subsets train --subsets test`
|
||||
# becomes `subsets: ["train", "test"]`.
|
||||
|
||||
Args:
|
||||
field_type: pydantic type.
|
||||
Args:
|
||||
field_type: pydantic type.
|
||||
|
||||
Returns:
|
||||
bool: true if it's a composite field (lists, containers and so on), false otherwise
|
||||
"""
|
||||
# Early out for mappings, since it's better to deal with them using strings.
|
||||
if is_mapping(field_type):
|
||||
return False
|
||||
# Activate multiple option for (simple) container types
|
||||
if is_container(field_type):
|
||||
args = parse_container_args(field_type)
|
||||
# A non-composite type has a single argument, such as 'List[int]'
|
||||
# A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
|
||||
# For the moment, only non-composite types are allowed.
|
||||
return not isinstance(args, tuple)
|
||||
Returns:
|
||||
bool: true if it's a composite field (lists, containers and so on), false otherwise
|
||||
"""
|
||||
# Early out for mappings, since it's better to deal with them using strings.
|
||||
if is_mapping(field_type):
|
||||
return False
|
||||
|
||||
# Activate multiple option for (simple) container types
|
||||
if is_container(field_type):
|
||||
args = parse_container_args(field_type)
|
||||
# A non-composite type has a single argument, such as 'List[int]'
|
||||
# A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
|
||||
# For the moment, only non-composite types are allowed.
|
||||
return not isinstance(args, tuple)
|
||||
return False
|
||||
|
||||
def is_mapping(field_type: type) -> bool:
|
||||
"""Checks whether this field represents a dictionary or JSON object.
|
||||
"""Checks whether this field represents a dictionary or JSON object.
|
||||
|
||||
Args:
|
||||
field_type (type): pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true when the field is a dict-like object, false otherwise.
|
||||
"""
|
||||
# Early out for standard containers.
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Mapping): return True
|
||||
# for everything else or when the typing is more complex, check its origin
|
||||
origin = t.get_origin(field_type)
|
||||
if origin is None: return False
|
||||
return lenient_issubclass(origin, t.Mapping)
|
||||
Args:
|
||||
field_type (type): pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true when the field is a dict-like object, false otherwise.
|
||||
"""
|
||||
# Early out for standard containers.
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Mapping): return True
|
||||
# for everything else or when the typing is more complex, check its origin
|
||||
origin = t.get_origin(field_type)
|
||||
if origin is None: return False
|
||||
return lenient_issubclass(origin, t.Mapping)
|
||||
|
||||
def is_container(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
|
||||
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
bool: true if a container, false otherwise
|
||||
"""
|
||||
# do not consider strings or byte arrays as containers
|
||||
if field_type in (str, bytes): return False
|
||||
# Early out for standard containers: list, tuple, range
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Container): return True
|
||||
origin = t.get_origin(field_type)
|
||||
# Early out for non-typing objects
|
||||
if origin is None: return False
|
||||
return lenient_issubclass(origin, t.Container)
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
bool: true if a container, false otherwise
|
||||
"""
|
||||
# do not consider strings or byte arrays as containers
|
||||
if field_type in (str, bytes): return False
|
||||
# Early out for standard containers: list, tuple, range
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Container): return True
|
||||
origin = t.get_origin(field_type)
|
||||
# Early out for non-typing objects
|
||||
if origin is None: return False
|
||||
return lenient_issubclass(origin, t.Container)
|
||||
|
||||
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType]:
|
||||
"""Parses the arguments inside a container type (lists, tuples and so on).
|
||||
"""Parses the arguments inside a container type (lists, tuples and so on).
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
ParamType | tuple[ParamType]: single click-compatible type or a tuple
|
||||
"""
|
||||
if not is_container(field_type):
|
||||
raise ValueError("Field type is not a container type.")
|
||||
args = t.get_args(field_type)
|
||||
# Early out for untyped containers: standard lists, tuples, List[Any]
|
||||
# Use strings when the type is unknown, avoid click's type guessing
|
||||
if len(args) == 0:
|
||||
return click_types.convert_type(str)
|
||||
# Early out for homogenous containers: Tuple[int], List[str]
|
||||
if len(args) == 1:
|
||||
return parse_single_arg(args[0])
|
||||
# Early out for homogenous tuples of indefinite length: Tuple[int, ...]
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
return parse_single_arg(args[0])
|
||||
# Then deal with fixed-length containers: Tuple[str, int, int]
|
||||
return tuple(parse_single_arg(arg) for arg in args)
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
ParamType | tuple[ParamType]: single click-compatible type or a tuple
|
||||
"""
|
||||
if not is_container(field_type):
|
||||
raise ValueError("Field type is not a container type.")
|
||||
args = t.get_args(field_type)
|
||||
# Early out for untyped containers: standard lists, tuples, List[Any]
|
||||
# Use strings when the type is unknown, avoid click's type guessing
|
||||
if len(args) == 0:
|
||||
return click_types.convert_type(str)
|
||||
# Early out for homogenous containers: Tuple[int], List[str]
|
||||
if len(args) == 1:
|
||||
return parse_single_arg(args[0])
|
||||
# Early out for homogenous tuples of indefinite length: Tuple[int, ...]
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
return parse_single_arg(args[0])
|
||||
# Then deal with fixed-length containers: Tuple[str, int, int]
|
||||
return tuple(parse_single_arg(arg) for arg in args)
|
||||
|
||||
def parse_single_arg(arg: type) -> ParamType:
|
||||
"""Returns the click-compatible type for container origin types.
|
||||
"""Returns the click-compatible type for container origin types.
|
||||
|
||||
In this case, returns string when it's not inferrable, a JSON for mappings
|
||||
and the original type itself in every other case (ints, floats and so on).
|
||||
Bytes is a special case, not natively handled by click.
|
||||
In this case, returns string when it's not inferrable, a JSON for mappings
|
||||
and the original type itself in every other case (ints, floats and so on).
|
||||
Bytes is a special case, not natively handled by click.
|
||||
|
||||
Args:
|
||||
arg (type): single argument
|
||||
|
||||
Returns:
|
||||
ParamType: click-compatible type
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
# When we don't know the type, we choose 'str'
|
||||
if arg is t.Any: return click_types.convert_type(str)
|
||||
# For containers and nested models, we use JSON
|
||||
if is_container(arg): return JsonType()
|
||||
if lenient_issubclass(arg, bytes): return BytesType()
|
||||
return click_types.convert_type(arg)
|
||||
Args:
|
||||
arg (type): single argument
|
||||
|
||||
Returns:
|
||||
ParamType: click-compatible type
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
# When we don't know the type, we choose 'str'
|
||||
if arg is t.Any: return click_types.convert_type(str)
|
||||
# For containers and nested models, we use JSON
|
||||
if is_container(arg): return JsonType()
|
||||
if lenient_issubclass(arg, bytes): return BytesType()
|
||||
return click_types.convert_type(arg)
|
||||
|
||||
class BytesType(ParamType):
|
||||
name = "bytes"
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
try:
|
||||
return str.encode(value)
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
|
||||
name = "bytes"
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
try:
|
||||
return str.encode(value)
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
|
||||
|
||||
CYGWIN = sys.platform.startswith("cygwin")
|
||||
WIN = sys.platform.startswith("win")
|
||||
if sys.platform.startswith("win") and WIN:
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
import locale
|
||||
def _get_argv_encoding() -> str:
|
||||
import locale
|
||||
|
||||
return locale.getpreferredencoding()
|
||||
return locale.getpreferredencoding()
|
||||
|
||||
else:
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
|
||||
|
||||
class CudaValueType(ParamType):
|
||||
name = "cuda"
|
||||
envvar_list_splitter = ","
|
||||
is_composite = True
|
||||
typ = click_types.convert_type(str)
|
||||
name = "cuda"
|
||||
envvar_list_splitter = ","
|
||||
is_composite = True
|
||||
typ = click_types.convert_type(str)
|
||||
|
||||
def split_envvar_value(self, rv: str) -> t.Sequence[str]:
|
||||
var = tuple(i for i in rv.split(self.envvar_list_splitter))
|
||||
if "-1" in var:
|
||||
return var[: var.index("-1")]
|
||||
return var
|
||||
def split_envvar_value(self, rv: str) -> t.Sequence[str]:
|
||||
var = tuple(i for i in rv.split(self.envvar_list_splitter))
|
||||
if "-1" in var:
|
||||
return var[:var.index("-1")]
|
||||
return var
|
||||
|
||||
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
|
||||
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
|
||||
|
||||
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
|
||||
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
|
||||
|
||||
Args:
|
||||
ctx: Invocation context for this command.
|
||||
param: The parameter that is requesting completion.
|
||||
incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
from ..utils import available_devices
|
||||
Args:
|
||||
ctx: Invocation context for this command.
|
||||
param: The parameter that is requesting completion.
|
||||
incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
from ..utils import available_devices
|
||||
|
||||
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
|
||||
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
|
||||
|
||||
return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
|
||||
return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes):
|
||||
enc = _get_argv_encoding()
|
||||
try:
|
||||
value = value.decode(enc)
|
||||
except UnicodeError:
|
||||
fs_enc = sys.getfilesystemencoding()
|
||||
if fs_enc != enc:
|
||||
try:
|
||||
value = value.decode(fs_enc)
|
||||
except UnicodeError:
|
||||
value = value.decode("utf-8", "replace")
|
||||
else:
|
||||
value = value.decode("utf-8", "replace")
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes):
|
||||
enc = _get_argv_encoding()
|
||||
try:
|
||||
value = value.decode(enc)
|
||||
except UnicodeError:
|
||||
fs_enc = sys.getfilesystemencoding()
|
||||
if fs_enc != enc:
|
||||
try:
|
||||
value = value.decode(fs_enc)
|
||||
except UnicodeError:
|
||||
value = value.decode("utf-8", "replace")
|
||||
else:
|
||||
value = value.decode("utf-8", "replace")
|
||||
|
||||
return tuple(self.typ(x, param, ctx) for x in value.split(","))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""CUDA is a click.STRING extension."""
|
||||
return "STRING"
|
||||
return tuple(self.typ(x, param, ctx) for x in value.split(","))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""CUDA is a click.STRING extension."""
|
||||
return "STRING"
|
||||
|
||||
CUDA = CudaValueType()
|
||||
|
||||
|
||||
class JsonType(ParamType):
|
||||
name = "json"
|
||||
name = "json"
|
||||
|
||||
def __init__(self, should_load: bool = True) -> None:
|
||||
"""Support JSON type for click.ParamType.
|
||||
def __init__(self, should_load: bool = True) -> None:
|
||||
"""Support JSON type for click.ParamType.
|
||||
|
||||
Args:
|
||||
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
|
||||
"""
|
||||
super().__init__()
|
||||
self.should_load = should_load
|
||||
Args:
|
||||
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
|
||||
"""
|
||||
super().__init__()
|
||||
self.should_load = should_load
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
from . import LazyType
|
||||
if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value
|
||||
try: return orjson.loads(value)
|
||||
except orjson.JSONDecodeError as exc: self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
from . import LazyType
|
||||
if LazyType[t.Mapping[str, str]](t.Mapping[str, str]).isinstance(value) or not self.should_load: return value
|
||||
try:
|
||||
return orjson.loads(value)
|
||||
except orjson.JSONDecodeError as exc:
|
||||
self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
|
||||
|
||||
@@ -19,27 +19,24 @@ from ..utils import DummyMetaclass
|
||||
from ..utils import require_backends
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models.auto.factory import _LazyAutoMapping
|
||||
from ..models.auto.factory import _LazyAutoMapping
|
||||
|
||||
class FlaxFlanT5(metaclass=DummyMetaclass):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["flax"])
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["flax"])
|
||||
|
||||
class FlaxOPT(metaclass=DummyMetaclass):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["flax"])
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["flax"])
|
||||
|
||||
class AutoFlaxLLM(metaclass=DummyMetaclass):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["flax"])
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["flax"])
|
||||
|
||||
MODEL_FLAX_MAPPING = t.cast("_LazyAutoMapping", None)
|
||||
|
||||
@@ -19,14 +19,13 @@ from ..utils import DummyMetaclass
|
||||
from ..utils import require_backends
|
||||
|
||||
class ChatGLM(metaclass=DummyMetaclass):
|
||||
_backends = ["torch", "cpm_kernels"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch", "cpm_kernels"])
|
||||
_backends = ["torch", "cpm_kernels"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch", "cpm_kernels"])
|
||||
|
||||
class Baichuan(metaclass=DummyMetaclass):
|
||||
_backends = ["torch", "cpm_kernels"]
|
||||
_backends = ["torch", "cpm_kernels"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch", "cpm_kernels"])
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch", "cpm_kernels"])
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..utils import DummyMetaclass
|
||||
from ..utils import require_backends
|
||||
|
||||
class Falcon(metaclass=DummyMetaclass):
|
||||
_backends = ["torch", "einops"]
|
||||
_backends = ["torch", "einops"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch", "einops"])
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch", "einops"])
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..utils import DummyMetaclass
|
||||
from ..utils import require_backends
|
||||
|
||||
class MPT(metaclass=DummyMetaclass):
|
||||
_backends = ["torch", "triton"]
|
||||
_backends = ["torch", "triton"]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch"])
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
require_backends(self, ["torch"])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user