Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-24 07:17:53 -05:00)
chore(releases): remove deadcode
Signed-off-by: Aaron Pham (mbp16) <29749331+aarnphm@users.noreply.github.com>
@@ -24,7 +24,7 @@ from openllm_core._typing_compat import (
  get_literal_args,
  TypedDict,
)
-from openllm_cli import termui
+from . import _termui as termui

if sys.version_info >= (3, 11):
  import tomllib

openllm-python/src/_openllm_tiny/_termui.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from __future__ import annotations
import enum, functools, logging, os, typing as t
import click, inflection, orjson
from openllm_core._typing_compat import DictStrAny, TypedDict
from openllm_core.utils import get_debug_mode

logger = logging.getLogger('openllm')


class Level(enum.IntEnum):
  NOTSET = logging.DEBUG
  DEBUG = logging.DEBUG
  INFO = logging.INFO
  WARNING = logging.WARNING
  ERROR = logging.ERROR
  CRITICAL = logging.CRITICAL

  @property
  def color(self) -> str | None:
    return {
      Level.NOTSET: None,
      Level.DEBUG: 'cyan',
      Level.INFO: 'green',
      Level.WARNING: 'yellow',
      Level.ERROR: 'red',
      Level.CRITICAL: 'red',
    }[self]

  @classmethod
  def from_logging_level(cls, level: int) -> Level:
    return {
      logging.DEBUG: Level.DEBUG,
      logging.INFO: Level.INFO,
      logging.WARNING: Level.WARNING,
      logging.ERROR: Level.ERROR,
      logging.CRITICAL: Level.CRITICAL,
    }[level]


class JsonLog(TypedDict):
  log_level: Level
  content: str


def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
  if get_debug_mode():
    echo(content, fg=fg)
  else:
    echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)


warning = functools.partial(log, level=Level.WARNING)
error = functools.partial(log, level=Level.ERROR)
critical = functools.partial(log, level=Level.CRITICAL)
debug = functools.partial(log, level=Level.DEBUG)
info = functools.partial(log, level=Level.INFO)
notset = functools.partial(log, level=Level.NOTSET)


def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
  if json:
    text = orjson.loads(text)
    if 'content' in text and 'log_level' in text:
      content = text['content']
      fg = Level.from_logging_level(text['log_level']).color
    else:
      content = orjson.dumps(text).decode()
      fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
  else:
    content = t.cast(str, text)
  attrs['fg'] = fg

  (click.echo if not _with_style else click.secho)(content, **attrs)


COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {
  'help_option_names': ['-h', '--help'],
  'max_content_width': COLUMNS,
  'token_normalize_func': inflection.underscore,
}
__all__ = ['COLUMNS', 'CONTEXT_SETTINGS', 'Level', 'critical', 'debug', 'echo', 'error', 'info', 'log', 'warning']
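A brief usage sketch of the new helpers, for readers following the openllm_cli.termui -> _termui switch above. The import path and call sites here are hypothetical and assembled only to illustrate the API; they are not part of this commit:

import click
from _openllm_tiny import _termui as termui  # hypothetical absolute import; in-package code uses `from . import _termui as termui`

termui.info('model loaded')                     # plain text when debug mode is on, JSON envelope otherwise
termui.warning('quantisation is experimental')  # the partials pre-bind the Level
termui.echo('hello', fg='green')                # raw echo, bypassing the JsonLog wrapper

@click.command(context_settings=termui.CONTEXT_SETTINGS)
def cli() -> None:
  termui.debug('starting CLI')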
@@ -27,14 +27,11 @@ __lazy = utils.LazyModule( # NOTE: update this to sys.modules[__name__] once my
    'exceptions': [],
    'client': ['HTTPClient', 'AsyncHTTPClient'],
    'bundle': [],
    'testing': [],
    'utils': ['api'],
    'entrypoints': ['mount_entrypoints'],
    'serialisation': ['ggml', 'transformers', 'vllm'],
    '_llm': ['LLM'],
    '_deprecated': ['Runner'],
    '_runners': ['runner'],
    '_quantisation': ['infer_quantisation_config'],
    '_strategies': ['CascadingResourceStrategy', 'get_resource'],
  },
  extra_objects={'COMPILED': COMPILED},
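For context, the `_import_structure` mapping above drives lazy attribute access on the package: `openllm.LLM` only imports the `_llm` submodule on first use. A simplified sketch of that mechanism, assuming `openllm_core.utils.LazyModule` follows the conventional lazy-module pattern (this is not the real implementation):

import importlib, types, typing as t

class LazyishModule(types.ModuleType):
  def __init__(self, name: str, import_structure: dict[str, list[str]]) -> None:
    super().__init__(name)
    # attribute name -> submodule that provides it, e.g. 'LLM' -> '_llm'
    self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

  def __getattr__(self, item: str) -> t.Any:
    if item in self._attr_to_module:
      submodule = importlib.import_module(f'.{self._attr_to_module[item]}', self.__name__)
      return getattr(submodule, item)
    raise AttributeError(item)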
@@ -44,7 +41,7 @@ __all__, __dir__ = __lazy.__all__, __lazy.__dir__
_BREAKING_INTERNAL = ['_service', '_service_vars']
_NEW_IMPL = ['LLM', *_BREAKING_INTERNAL]

-if (_BENTOML_VERSION := utils.pkg.pkg_version_info('bentoml')) > (1, 2):
+if utils.pkg.pkg_version_info('bentoml') > (1, 2):
  import _openllm_tiny as _tiny
else:
  _tiny = None
@@ -58,7 +55,7 @@ def __getattr__(name: str) -> _t.Any:
      f'"{name}" is an internal implementation and considered breaking with older OpenLLM. Please migrate your code if you depend on this.'
    )
    _warnings.warn(
-      f'"{name}" is considered deprecated implementation and will be removed in the future. Make sure to upgrade to OpenLLM 0.5.x',
+      f'"{name}" is considered deprecated implementation and could be breaking. See https://github.com/bentoml/OpenLLM for more information on upgrading instruction.',
      DeprecationWarning,
      stacklevel=3,
    )
@@ -26,7 +26,6 @@ from . import (
  exceptions as exceptions,
  serialisation as serialisation,
  utils as utils,
-  entrypoints as entrypoints,
)
from .serialisation import ggml as ggml, transformers as transformers, vllm as vllm
from ._deprecated import Runner as Runner
@@ -1,16 +0,0 @@
|
||||
import importlib
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure = {'openai': [], 'hf': []}
|
||||
|
||||
|
||||
def mount_entrypoints(svc, llm):
|
||||
for module_name in _import_structure:
|
||||
svc = importlib.import_module(f'.{module_name}', __name__).mount_to_svc(svc, llm)
|
||||
return svc
|
||||
|
||||
|
||||
__lazy = LazyModule(
|
||||
__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}
|
||||
)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Entrypoint for all third-party apps.
|
||||
|
||||
Currently support OpenAI compatible API.
|
||||
|
||||
Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from _bentoml_sdk import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from . import hf as hf, openai as openai
|
||||
from .._llm import LLM
|
||||
|
||||
def mount_entrypoints(svc: Service[Any], llm: LLM[M, T]) -> Service: ...
|
||||
@@ -1,641 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
from starlette.routing import Host, Mount, Route
|
||||
from starlette.schemas import EndpointInfo, SchemaGenerator
|
||||
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import pydantic
|
||||
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODELS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >
|
||||
List and describe the various models available in the API.
|
||||
|
||||
You can refer to the available supported models with `openllm models` for more
|
||||
information.
|
||||
operationId: openai__list_models
|
||||
produces:
|
||||
- application/json
|
||||
summary: Describes a model offering that can be used with the API.
|
||||
tags:
|
||||
- OpenAI
|
||||
x-bentoml-name: list_models
|
||||
responses:
|
||||
200:
|
||||
description: The Model object
|
||||
content:
|
||||
application/json:
|
||||
example:
|
||||
object: 'list'
|
||||
data:
|
||||
- id: __model_id__
|
||||
object: model
|
||||
created: 1686935002
|
||||
owned_by: 'na'
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelList'
|
||||
"""
|
||||
CHAT_COMPLETIONS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a list of messages comprising a conversation, the model will return a
|
||||
response.
|
||||
operationId: openai__chat_completions
|
||||
produces:
|
||||
- application/json
|
||||
tags:
|
||||
- OpenAI
|
||||
x-bentoml-name: create_chat_completions
|
||||
summary: Creates a model response for the given chat conversation.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot input example
|
||||
value:
|
||||
messages: __chat_messages__
|
||||
model: __model_id__
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: false
|
||||
chat_template: __chat_template__
|
||||
add_generation_prompt: __add_generation_prompt__
|
||||
echo: false
|
||||
streaming:
|
||||
summary: Streaming input example
|
||||
value:
|
||||
messages:
|
||||
- role: system
|
||||
content: You are a helpful assistant.
|
||||
- role: user
|
||||
content: Hello, I'm looking for a chatbot that can help me with my work.
|
||||
model: __model_id__
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: true
|
||||
stop:
|
||||
- "<|endoftext|>"
|
||||
chat_template: __chat_template__
|
||||
add_generation_prompt: __add_generation_prompt__
|
||||
echo: false
|
||||
schema:
|
||||
$ref: '#/components/schemas/ChatCompletionRequest'
|
||||
responses:
|
||||
200:
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ChatCompletionResponse'
|
||||
examples:
|
||||
streaming:
|
||||
summary: Streaming output example
|
||||
value: >
|
||||
{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
one-shot:
|
||||
summary: One-shot output example
|
||||
value: >
|
||||
{"id": "chatcmpl-123", "object": "chat.completion", "created": 1677652288, "model": "gpt-3.5-turbo-0613", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello there, how may I assist you today?"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}}
|
||||
404:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
wrong-model:
|
||||
summary: Wrong model
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Model 'meta-llama--Llama-2-13b-chat-hf' does not exists. Try 'GET /v1/models' to see available models.\\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 404
|
||||
}
|
||||
}
|
||||
description: NotFound
|
||||
500:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-parameters:
|
||||
summary: Invalid parameters
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "`top_p` has to be a float > 0 and < 1, but is 4.0",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 500
|
||||
}
|
||||
}
|
||||
description: Internal Server Error
|
||||
400:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-json:
|
||||
summary: Invalid JSON sent
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Invalid JSON input received (Check server log).",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
invalid-prompt:
|
||||
summary: Invalid prompt
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Please provide a prompt.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
"""
|
||||
COMPLETIONS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. We recommend most users use our Chat completions API.
|
||||
operationId: openai__completions
|
||||
produces:
|
||||
- application/json
|
||||
tags:
|
||||
- OpenAI
|
||||
x-bentoml-name: create_completions
|
||||
summary: Creates a completion for the provided prompt and parameters.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CompletionRequest'
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot input example
|
||||
value:
|
||||
prompt: This is a test
|
||||
model: __model_id__
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
logprobs: null
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: false
|
||||
streaming:
|
||||
summary: Streaming input example
|
||||
value:
|
||||
prompt: This is a test
|
||||
model: __model_id__
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
logprobs: null
|
||||
n: 1
|
||||
stream: true
|
||||
stop:
|
||||
- "<|endoftext|>"
|
||||
responses:
|
||||
200:
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CompletionResponse'
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot output example
|
||||
value:
|
||||
id: cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7
|
||||
object: text_completion
|
||||
created: 1589478378
|
||||
model: VAR_model_id
|
||||
choices:
|
||||
- text: This is indeed a test
|
||||
index: 0
|
||||
logprobs: null
|
||||
finish_reason: length
|
||||
usage:
|
||||
prompt_tokens: 5
|
||||
completion_tokens: 7
|
||||
total_tokens: 12
|
||||
streaming:
|
||||
summary: Streaming output example
|
||||
value:
|
||||
id: cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe
|
||||
object: text_completion
|
||||
created: 1690759702
|
||||
choices:
|
||||
- text: This
|
||||
index: 0
|
||||
logprobs: null
|
||||
finish_reason: null
|
||||
model: gpt-3.5-turbo-instruct
|
||||
404:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
wrong-model:
|
||||
summary: Wrong model
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Model 'meta-llama--Llama-2-13b-chat-hf' does not exists. Try 'GET /v1/models' to see available models.\\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 404
|
||||
}
|
||||
}
|
||||
description: NotFound
|
||||
500:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-parameters:
|
||||
summary: Invalid parameters
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "`top_p` has to be a float > 0 and < 1, but is 4.0",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 500
|
||||
}
|
||||
}
|
||||
description: Internal Server Error
|
||||
400:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-json:
|
||||
summary: Invalid JSON sent
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Invalid JSON input received (Check server log).",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
invalid-prompt:
|
||||
summary: Invalid prompt
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Please provide a prompt.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
"""
|
||||
HF_AGENT_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: Generate instruction for given HF Agent chain for all OpenLLM supported models.
|
||||
operationId: hf__agent
|
||||
summary: Generate instruction for given HF Agent.
|
||||
tags:
|
||||
- HF
|
||||
x-bentoml-name: hf_agent
|
||||
produces:
|
||||
- application/json
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentRequest'
|
||||
example:
|
||||
inputs: "Is the following `text` positive or negative?"
|
||||
parameters:
|
||||
text: "This is a positive text."
|
||||
stop: []
|
||||
required: true
|
||||
responses:
|
||||
200:
|
||||
description: Successfull generated instruction.
|
||||
content:
|
||||
application/json:
|
||||
example:
|
||||
- generated_text: "This is a generated instruction."
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentResponse'
|
||||
400:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Bad Request
|
||||
500:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
"""
|
||||
HF_ADAPTERS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: Return current list of adapters for given LLM.
|
||||
operationId: hf__adapters_map
|
||||
produces:
|
||||
- application/json
|
||||
summary: Describes a model offering that can be used with the API.
|
||||
tags:
|
||||
- HF
|
||||
x-bentoml-name: hf_adapters
|
||||
responses:
|
||||
200:
|
||||
description: Return list of LoRA adapters.
|
||||
content:
|
||||
application/json:
|
||||
example:
|
||||
aarnphm/opt-6-7b-quotes:
|
||||
adapter_name: default
|
||||
adapter_type: LORA
|
||||
aarnphm/opt-6-7b-dolly:
|
||||
adapter_name: dolly
|
||||
adapter_type: LORA
|
||||
500:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
"""
|
||||
COHERE_GENERATE_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a prompt, the model will return one or more predicted completions, and
|
||||
can also return the probabilities of alternative tokens at each position.
|
||||
operationId: cohere__generate
|
||||
produces:
|
||||
- application/json
|
||||
tags:
|
||||
- Cohere
|
||||
x-bentoml-name: cohere_generate
|
||||
summary: Creates a completion for the provided prompt and parameters.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CohereGenerateRequest'
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot input example
|
||||
value:
|
||||
prompt: This is a test
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
p: 0.43
|
||||
k: 12
|
||||
num_generations: 2
|
||||
stream: false
|
||||
streaming:
|
||||
summary: Streaming input example
|
||||
value:
|
||||
prompt: This is a test
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
p: 0.43
|
||||
k: 12
|
||||
num_generations: 2
|
||||
stream: true
|
||||
stop_sequences:
|
||||
- "<|endoftext|>"
|
||||
"""
|
||||
COHERE_CHAT_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a list of messages comprising a conversation, the model will return a response.
|
||||
operationId: cohere__chat
|
||||
produces:
|
||||
- application/json
|
||||
tags:
|
||||
- Cohere
|
||||
x-bentoml-name: cohere_chat
|
||||
summary: Creates a model response for the given chat conversation.
|
||||
"""
|
||||
|
||||
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
|
||||
|
||||
|
||||
def apply_schema(func, **attrs):
|
||||
for k, v in attrs.items():
|
||||
func.__doc__ = func.__doc__.replace(k, v)
|
||||
return func
|
||||
|
||||
|
||||
def add_schema_definitions(func):
|
||||
append_str = _SCHEMAS.get(func.__name__.lower(), '')
|
||||
if not append_str:
|
||||
return func
|
||||
if func.__doc__ is None:
|
||||
func.__doc__ = ''
|
||||
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
|
||||
return func
|
||||
|
||||
|
||||
class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
def get_endpoints(self, routes):
|
||||
endpoints_info = []
|
||||
for route in routes:
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
|
||||
sub_endpoints = [
|
||||
EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func)
|
||||
for sub_endpoint in self.get_endpoints(routes)
|
||||
]
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
elif (
|
||||
inspect.isfunction(route.endpoint)
|
||||
or inspect.ismethod(route.endpoint)
|
||||
or isinstance(route.endpoint, functools.partial)
|
||||
):
|
||||
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ['GET']:
|
||||
if method == 'HEAD':
|
||||
continue
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
|
||||
if not hasattr(route.endpoint, method):
|
||||
continue
|
||||
func = getattr(route.endpoint, method)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
return endpoints_info
|
||||
|
||||
def get_schema(self, routes, mount_path=None):
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault('paths', {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
if mount_path:
|
||||
mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
|
||||
|
||||
for endpoint in endpoints_info:
|
||||
parsed = self.parse_docstring(endpoint.func)
|
||||
if not parsed:
|
||||
continue
|
||||
|
||||
path = endpoint.path if mount_path is None else mount_path + endpoint.path
|
||||
if path not in schema['paths']:
|
||||
schema['paths'][path] = {}
|
||||
schema['paths'][path][endpoint.http_method] = parsed
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_generator(title, components=None, tags=None, inject=True):
|
||||
base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION}
|
||||
if components and inject:
|
||||
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags and inject:
|
||||
base_schema['tags'] = tags
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
|
||||
def component_schema_generator(attr_cls: pydantic.BaseModel, description=None):
|
||||
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
for name, field in attr_cls.model_fields.items():
|
||||
attr_type = field.annotation
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
|
||||
# Map Python types to OpenAPI schema types
|
||||
if isinstance(attr_type, str):
|
||||
schema_type = 'string'
|
||||
elif isinstance(attr_type, int):
|
||||
schema_type = 'integer'
|
||||
elif isinstance(attr_type, float):
|
||||
schema_type = 'number'
|
||||
elif isinstance(attr_type, bool):
|
||||
schema_type = 'boolean'
|
||||
elif origin_type is list or origin_type is tuple:
|
||||
schema_type = 'array'
|
||||
elif origin_type is dict:
|
||||
schema_type = 'object'
|
||||
# Assuming string keys for simplicity, and handling Any type for values
|
||||
prop_schema = {'type': 'object', 'additionalProperties': True if args_type[1] is t.Any else {'type': 'string'}}
|
||||
elif attr_type == t.Optional[str]:
|
||||
schema_type = 'string'
|
||||
elif origin_type is t.Union and t.Any in args_type:
|
||||
schema_type = 'object'
|
||||
prop_schema = {'type': 'object', 'additionalProperties': True}
|
||||
else:
|
||||
schema_type = 'string'
|
||||
|
||||
if 'prop_schema' not in locals():
|
||||
prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
|
||||
prop_schema['default'] = field.default
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
|
||||
schema['required'].append(name)
|
||||
schema['properties'][name] = prop_schema
|
||||
locals().pop('prop_schema', None)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
_SimpleSchema = types.new_class(
|
||||
'_SimpleSchema',
|
||||
(object,),
|
||||
{},
|
||||
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
|
||||
)
|
||||
|
||||
|
||||
def append_schemas(svc, generated_schema, tags_order='prepend', inject=True):
|
||||
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
|
||||
from bentoml._internal.service.openapi.specification import OpenAPISpecification
|
||||
|
||||
if not inject:
|
||||
return svc
|
||||
|
||||
svc_schema = svc.openapi_spec
|
||||
if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)):
|
||||
svc_schema = svc_schema.asdict()
|
||||
if 'tags' in generated_schema:
|
||||
if tags_order == 'prepend':
|
||||
svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
|
||||
elif tags_order == 'append':
|
||||
svc_schema['tags'].extend(generated_schema['tags'])
|
||||
else:
|
||||
raise ValueError(f'Invalid tags_order: {tags_order}')
|
||||
if 'components' in generated_schema:
|
||||
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
|
||||
svc_schema['paths'].update(generated_schema['paths'])
|
||||
|
||||
# HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
from bentoml._internal.service import openapi
|
||||
|
||||
def _generate_spec(svc, openapi_version=OPENAPI_VERSION):
|
||||
return _SimpleSchema(svc_schema)
|
||||
|
||||
def asdict(self):
|
||||
return svc_schema
|
||||
|
||||
openapi.generate_spec = _generate_spec
|
||||
OpenAPISpecification.asdict = asdict
|
||||
return svc
|
||||
@@ -1,29 +0,0 @@
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Type
|
||||
|
||||
from attr import AttrsInstance
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.schemas import EndpointInfo
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
class OpenLLMSchemaGenerator:
|
||||
base_schema: Dict[str, Any]
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: ...
|
||||
def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ...
|
||||
def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
|
||||
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
|
||||
def append_schemas(
|
||||
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
|
||||
) -> Service: ...
|
||||
def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ...
|
||||
def get_generator(
|
||||
title: str,
|
||||
components: Optional[List[Type[AttrsInstance]]] = ...,
|
||||
tags: Optional[List[Dict[str, Any]]] = ...,
|
||||
inject: bool = ...,
|
||||
) -> OpenLLMSchemaGenerator: ...
|
||||
@@ -1,63 +0,0 @@
|
||||
import functools, logging
|
||||
from http import HTTPStatus
|
||||
import orjson
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from openllm_core.utils import converter
|
||||
from ._openapi import add_schema_definitions, append_schemas, get_generator
|
||||
from ..protocol.hf import AgentRequest, AgentResponse, HFErrorResponse
|
||||
|
||||
schemas = get_generator(
|
||||
'hf',
|
||||
components=[AgentRequest, AgentResponse, HFErrorResponse],
|
||||
tags=[
|
||||
{
|
||||
'name': 'HF',
|
||||
'description': 'HF integration, including Agent and others schema endpoints.',
|
||||
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent',
|
||||
}
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
],
|
||||
)
|
||||
mount_path = '/hf'
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def hf_agent(req, llm):
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), AgentRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
|
||||
stop = request.parameters.pop('stop', [])
|
||||
try:
|
||||
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
|
||||
return JSONResponse(
|
||||
converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error('Error while generating: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
|
||||
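For orientation, the removed /hf/agent route consumed an AgentRequest-shaped JSON body (the same shape as the example in the HF agent schema above); a client call against a locally served model would have looked roughly like this. This is an illustrative sketch only, assuming the default '/hf' mount path shown above and a server listening on port 3000:

import requests

resp = requests.post(
  'http://localhost:3000/hf/agent',
  json={
    'inputs': 'Is the following `text` positive or negative?',
    'parameters': {'text': 'This is a positive text.', 'stop': []},
  },
  timeout=30,
)
print(resp.json())  # e.g. [{'generated_text': '...'}]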
@@ -1,14 +0,0 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def hf_agent(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
def hf_adapters(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
@@ -1,448 +0,0 @@
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from http import HTTPStatus
|
||||
|
||||
import orjson
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from openllm_core.utils import converter, gen_random_uuid
|
||||
|
||||
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
|
||||
from openllm_core.protocol.openai import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
Delta,
|
||||
ErrorResponse,
|
||||
LogProbs,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
schemas = get_generator(
|
||||
'openai',
|
||||
components=[
|
||||
ErrorResponse,
|
||||
ModelList,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionStreamResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionStreamResponse,
|
||||
],
|
||||
tags=[
|
||||
{
|
||||
'name': 'OpenAI',
|
||||
'description': 'OpenAI Compatible API support',
|
||||
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object',
|
||||
}
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def jsonify_attr(obj):
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
{
|
||||
'error': converter.unstructure(
|
||||
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
|
||||
)
|
||||
},
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
async def check_model(request, model): # noqa
|
||||
if request.model == model:
|
||||
return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
)
|
||||
|
||||
|
||||
def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0, *, llm):
|
||||
# Create OpenAI-style logprobs.
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
token_logprob = None
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id]
|
||||
token = llm.tokenizer.convert_ids_to_tokens(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append(
|
||||
{llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()}
|
||||
if step_top_logprobs
|
||||
else None
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
def mount_to_svc(svc, llm):
|
||||
list_models.__doc__ = list_models.__doc__.replace('__model_id__', llm.llm_type)
|
||||
completions.__doc__ = completions.__doc__.replace('__model_id__', llm.llm_type)
|
||||
chat_completions.__doc__ = chat_completions.__doc__.replace('__model_id__', llm.llm_type)
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route(
|
||||
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
|
||||
),
|
||||
Route(
|
||||
'/completions',
|
||||
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
|
||||
methods=['POST'],
|
||||
),
|
||||
Route(
|
||||
'/chat/completions',
|
||||
functools.partial(
|
||||
apply_schema(
|
||||
chat_completions,
|
||||
__model_id__=llm.llm_type,
|
||||
__chat_template__=orjson.dumps(llm.config.chat_template).decode(),
|
||||
__chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
|
||||
__add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False),
|
||||
),
|
||||
llm=llm,
|
||||
),
|
||||
methods=['POST'],
|
||||
),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
],
|
||||
)
|
||||
svc.mount_asgi_app(app, path='/v1')
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path='/v1'))
|
||||
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions
|
||||
def list_models(_, llm):
|
||||
return JSONResponse(
|
||||
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions
|
||||
async def chat_completions(req, llm):
|
||||
# TODO: Check for length based on model context_length
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.monotonic())
|
||||
prompt = llm.tokenizer.apply_chat_template(
|
||||
request.messages,
|
||||
tokenize=False,
|
||||
chat_template=request.chat_template if request.chat_template != 'None' else None,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
def get_role() -> str:
|
||||
return (
|
||||
request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'
|
||||
) # TODO: Support custom role here.
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index, text, finish_reason=None, usage=None):
|
||||
response = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
|
||||
],
|
||||
)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
return jsonify_attr(response)
|
||||
|
||||
async def completion_stream_generator():
|
||||
# first chunk with role
|
||||
role = get_role()
|
||||
for i in range(config['n']):
|
||||
yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
|
||||
|
||||
if request.echo:
|
||||
last_message, last_content = request.messages[-1], ''
|
||||
if last_message.get('content') and last_message.get('role') == role:
|
||||
last_content = last_message['content']
|
||||
if last_content:
|
||||
for i in range(config['n']):
|
||||
yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'
|
||||
|
||||
previous_num_tokens = [0] * config['n']
|
||||
finish_reason_sent = [False] * config['n']
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
if finish_reason_sent[output.index]:
|
||||
continue
|
||||
yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
|
||||
previous_num_tokens[output.index] += len(output.token_ids)
|
||||
if output.finish_reason is not None:
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
|
||||
yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
|
||||
finish_reason_sent[output.index] = True
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.model_copy(
|
||||
update=dict(
|
||||
outputs=[
|
||||
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
role = get_role()
|
||||
choices = [
|
||||
ChatCompletionResponseChoice(
|
||||
index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason
|
||||
)
|
||||
for output in final_result.outputs
|
||||
]
|
||||
if request.echo:
|
||||
last_message, last_content = request.messages[-1], ''
|
||||
if last_message.get('content') and last_message.get('role') == role:
|
||||
last_content = last_message['content']
|
||||
for choice in choices:
|
||||
full_message = last_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
|
||||
)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
# POST /v1/completions
|
||||
@add_schema_definitions
|
||||
async def completions(req, llm):
|
||||
# TODO: Check for length based on model context_length
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received legacy completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
# OpenAI API supports echoing the prompt when max_tokens is 0.
|
||||
echo_without_generation = request.echo and request.max_tokens == 0
|
||||
if echo_without_generation:
|
||||
request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back.
|
||||
|
||||
if request.suffix is not None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
if not request.prompt:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('cmpl')
|
||||
created_time = int(time.monotonic())
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
# best_of != n then we don't stream
|
||||
# TODO: support use_beam_search
|
||||
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
|
||||
|
||||
def create_stream_response_json(index, text, logprobs=None, finish_reason=None, usage=None):
|
||||
response = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
|
||||
)
|
||||
if usage:
|
||||
response.usage = usage
|
||||
return jsonify_attr(response)
|
||||
|
||||
async def completion_stream_generator():
|
||||
previous_num_tokens = [0] * config['n']
|
||||
previous_texts = [''] * config['n']
|
||||
previous_echo = [False] * config['n']
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text
|
||||
token_ids = output.token_ids
|
||||
logprobs = None
|
||||
top_logprobs = None
|
||||
if request.logprobs is not None:
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i] :]
|
||||
|
||||
if request.echo and not previous_echo[i]:
|
||||
if not echo_without_generation:
|
||||
delta_text = res.prompt + delta_text
|
||||
token_ids = res.prompt_token_ids + token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs + top_logprobs
|
||||
else:
|
||||
delta_text = res.prompt
|
||||
token_ids = res.prompt_token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs
|
||||
previous_echo[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
output.token_ids,
|
||||
output.logprobs[previous_num_tokens[i] :],
|
||||
request.logprobs,
|
||||
len(previous_texts[i]),
|
||||
llm=llm,
|
||||
)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
previous_texts[i] += output.text
|
||||
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
|
||||
if output.finish_reason is not None:
|
||||
logprobs = LogProbs() if request.logprobs is not None else None
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
|
||||
yield f'data: {create_stream_response_json(i, "", logprobs, output.finish_reason, usage)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.model_copy(
|
||||
update=dict(
|
||||
outputs=[
|
||||
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
choices = []
|
||||
prompt_token_ids = final_result.prompt_token_ids
|
||||
prompt_logprobs = final_result.prompt_logprobs
|
||||
prompt_text = final_result.prompt
|
||||
for output in final_result.outputs:
|
||||
logprobs = None
|
||||
if request.logprobs is not None:
|
||||
if not echo_without_generation:
|
||||
token_ids, top_logprobs = output.token_ids, output.logprobs
|
||||
if request.echo:
|
||||
token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
|
||||
else:
|
||||
token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
|
||||
logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm)
|
||||
if not echo_without_generation:
|
||||
output_text = output.text
|
||||
if request.echo:
|
||||
output_text = prompt_text + output_text
|
||||
else:
|
||||
output_text = prompt_text
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
|
||||
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
@@ -1,30 +0,0 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs
|
||||
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest], model: str
|
||||
) -> Optional[JSONResponse]: ...
|
||||
def create_logprobs(
|
||||
token_ids: List[int],
|
||||
top_logprobs: List[Dict[int, float]], #
|
||||
num_output_top_logprobs: Optional[int] = ...,
|
||||
initial_text_offset: int = ...,
|
||||
*,
|
||||
llm: LLM[M, T],
|
||||
) -> LogProbs: ...
|
||||
def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def completions(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
@@ -1,366 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, logging, os, typing as t
|
||||
import bentoml, openllm, click, inflection, click_option_group as cog
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import shell_completion as sc
|
||||
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import (
|
||||
Concatenate,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralSerialisation,
|
||||
ParamSpec,
|
||||
AnyCallable,
|
||||
get_literal_args,
|
||||
)
|
||||
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec('P')
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [
|
||||
sc.CompletionItem(str(it.tag), help='Bento')
|
||||
for it in bentoml.list()
|
||||
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
|
||||
]
|
||||
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None,
|
||||
cors: bool,
|
||||
environ: DictStrAny,
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
_bentoml_config_options_opts = [
|
||||
'tracing.sample_rate=1.0',
|
||||
'api_server.max_runner_connections=25',
|
||||
f'runners."llm-{config["start_name"]}-runner".batching.max_batch_size=128',
|
||||
f'api_server.traffic.timeout={server_timeout}',
|
||||
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
if device:
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
|
||||
)
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend([
|
||||
'api_server.http.cors.enabled=true',
|
||||
'api_server.http.cors.access_control_allow_origins="*"',
|
||||
])
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
])
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG:
|
||||
logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
return environ
|
||||
|
||||
|
||||
_adapter_mapping_key = 'adapter_map'
|
||||
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value:
|
||||
return None
|
||||
if _adapter_mapping_key not in ctx.params:
|
||||
ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try:
|
||||
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = name
|
||||
return None
|
||||
|
||||
|
||||
def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]:
|
||||
shared = [
|
||||
dtype_option(factory=factory),
|
||||
revision_option(factory=factory), #
|
||||
backend_option(factory=factory),
|
||||
quantize_option(factory=factory), #
|
||||
serialisation_option(factory=factory),
|
||||
]
|
||||
if not _eager:
|
||||
return shared
|
||||
return compose(*shared)(fn)
|
||||
|
||||
|
||||
# NOTE: A list of BentoML options that are not needed for parsing.
# NOTE: Users shouldn't set '--working-dir', as OpenLLM will set this up.
# NOTE: '--production' is also deprecated.
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}


def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
  from bentoml_cli.cli import cli

  group = cog.optgroup.group(
    'Start an HTTP server options', help='Related to serving the model [synonymous with `bentoml serve-http`]'
  )

  def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
    serve_command = cli.commands['serve']
    # The first parameter is the 'bento' argument.
    # The last five are the common parameters counted by BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS.
    serve_options = [
      p
      for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
      if p.name not in _IGNORED_OPTIONS
    ]
    for options in reversed(serve_options):
      attrs = options.to_info_dict()
      # we don't need 'param_type_name', since they should all be options
      attrs.pop('param_type_name')
      # 'name' is not a valid argument
      attrs.pop('name')
      # 'type' can be determined from the default value
      attrs.pop('type')
      param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts'))
      f = cog.optgroup.option(*param_decls, **attrs)(f)
    return group(f)

  return decorator


def start_decorator(fn: FC) -> FC:
  composed = compose(
    cog.optgroup.group(
      'LLM Options',
      help="""The following options are related to running the LLM server, as well as optimization options.

OpenLLM supports k-bit quantization (8-bit, 4-bit), GPTQ quantization, and PagedAttention via vLLM.

The following are either on our roadmap or currently being worked on:

- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
""",
    ),
    *optimization_decorator(fn, factory=cog.optgroup, _eager=False),
    cog.optgroup.option(
      '--device',
      type=dantic.CUDA,
      multiple=True,
      envvar='CUDA_VISIBLE_DEVICES',
      callback=parse_device_callback,
      help='Assign GPU devices (if available)',
      show_envvar=True,
    ),
    adapter_id_option(factory=cog.optgroup),
  )

  return composed(fn)


def parse_device_callback(
  _: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
) -> t.Tuple[str, ...] | None:
  if value is None:
    return value
  el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
  # NOTE: --device all is a special case
  if len(el) == 1 and el[0] == 'all':
    return tuple(map(str, openllm.utils.available_devices()))
  return el


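# Hedged illustration of the device callback above (standalone, made-up values): with
# `multiple=True`, click hands the callback a tuple of tuples, which is flattened before the
# special-cased '--device all' lookup.
def _sketch_flatten_devices(value: tuple) -> tuple:
  return tuple(i for k in value for i in k)


# _sketch_flatten_devices((('0',), ('1', '3'))) -> ('0', '1', '3')

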
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
  """General ``@click`` decorator with some sauce.

  This decorator extends the default ``@click.option`` with a ``factory`` keyword and a factory ``attr``
  to provide a type-safe ``click.option`` or ``click.argument`` wrapper for all compatible factories.
  """
  factory = attrs.pop('factory', click)
  factory_attr = attrs.pop('attr', 'option')
  if factory_attr != 'argument':
    attrs.setdefault('help', 'General option for OpenLLM CLI.')

  def decorator(f: FC | None) -> FC:
    callback = getattr(factory, factory_attr, None)
    if callback is None:
      raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
    return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))

  return decorator


cli_option = functools.partial(_click_factory_type, attr='option')
cli_argument = functools.partial(_click_factory_type, attr='argument')


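# Hedged usage sketch: a hypothetical '--toy-flag' option built with the `cli_option` partial
# above. Like the real helpers below, it forwards `factory=` so the same definition can target
# either plain `click` or `cog.optgroup`.
def toy_flag_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
  return cli_option('--toy-flag', is_flag=True, default=False, help='A made-up example flag.', **attrs)(f)

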
def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
|
||||
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--dtype',
|
||||
type=str,
|
||||
envvar='TORCH_DTYPE',
|
||||
default='auto',
|
||||
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_MODEL_ID',
|
||||
show_envvar=True,
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def revision_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--revision',
|
||||
'--model-version',
|
||||
'model_version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='Runtime to use for both serialisation/inference engine.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=str,
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
- ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.
|
||||
|
||||
- ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.
|
||||
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
|
||||
|
||||
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
||||
|
||||
> [!NOTE] The model can also be served with quantized weights.
|
||||
"""
|
||||
+ (
|
||||
"""
|
||||
> [!NOTE] This will set the mode used for serving within the deployment."""
|
||||
if build
|
||||
else ''
|
||||
),
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
- ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.
|
||||
|
||||
> [!NOTE] Safetensors might not work for every case, and you can always fall back to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
""",
|
||||
**attrs,
|
||||
)(f)
|
||||
@@ -1,276 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t, bentoml, openllm_core, orjson
|
||||
from simple_di import Provide, inject
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_core._typing_compat import LiteralSerialisation
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import WARNING_ENV_VAR, codegen, first_not_none, get_disable_warnings, is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import LiteralBackend, LiteralQuantise, LiteralString
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _start(
|
||||
model_id: str,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
|
||||
|
||||
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
|
||||
pass ``--port 5001``, you can pass ``additional_args=["--port", "5001"]``
|
||||
|
||||
> [!NOTE] This creates a blocking process, so if you use this API, you may want to start the server
> in a separate thread instead of blocking the main thread.
|
||||
|
||||
``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction.
|
||||
|
||||
Args:
|
||||
model_id: The model id to start this LLMServer
|
||||
timeout: The server timeout
|
||||
workers_per_resource: Number of workers per resource assigned.
|
||||
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
|
||||
for more information. By default, this is set to 1.
|
||||
|
||||
> [!NOTE] ``--workers-per-resource`` will also accept the following strategies:
|
||||
> - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
> - ``conserved``: This will determine the number of available GPU resources, and only assign
|
||||
> one worker for the LLMRunner. For example, if there are 4 GPUs available, then ``conserved`` is
|
||||
> equivalent to ``--workers-per-resource 0.25``.
|
||||
device: Assign GPU devices (if available) to this LLM. By default, this is set to ``None``. It also accepts 'all'
|
||||
argument to assign all available GPUs to this LLM.
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
- int4: Quantize the model with 4bit (bitsandbytes required)
|
||||
- gptq: Quantize the model with GPTQ (auto-gptq required)
|
||||
cors: Whether to enable CORS for this LLM. By default, this is set to ``False``.
|
||||
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
|
||||
backend: The backend to use for this LLM. By default, this is set to ``pt``.
|
||||
additional_args: Additional arguments to pass to ``openllm start``.
|
||||
"""
|
||||
from .entrypoint import start_command
|
||||
|
||||
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
args: list[str] = [model_id]
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend([
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
|
||||
])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
|
||||
args.extend(['--device', ','.join(device)])
|
||||
if quantize:
|
||||
args.extend(['--quantize', str(quantize)])
|
||||
if cors:
|
||||
args.append('--cors')
|
||||
if adapter_map:
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()])
|
||||
)
|
||||
)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if __test__:
|
||||
args.append('--return-process')
|
||||
|
||||
cmd = start_command
|
||||
return cmd.main(args=args, standalone_mode=False)
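

# Hedged usage sketch of the Python API above (not executed on import; the model id is only an
# example of a HF id that OpenLLM can serve):
def _example_start() -> None:
  import openllm

  # Starts a blocking server on port 3000 with CORS enabled, mirroring
  # `openllm start facebook/opt-1.3b --port 3000 --cors`.
  openllm.start('facebook/opt-1.3b', timeout=360, cors=True, additional_args=['--port', '3000'])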
|
||||
|
||||
|
||||
@inject
|
||||
def _build(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
push: bool = False,
|
||||
force_push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
) -> bentoml.Bento:
|
||||
"""Package a LLM into a BentoLLM.
|
||||
|
||||
The LLM will be built into a BentoService with the following structure:
|
||||
if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time.
|
||||
|
||||
``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI.
|
||||
|
||||
Args:
|
||||
model_id: The model id to build this BentoLLM
|
||||
model_version: Optional model version for this given LLM
|
||||
bento_version: Optional bento version for this given BentoLLM
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
- int4: Quantize the model with 4bit (bitsandbytes required)
|
||||
- gptq: Quantize the model with GPTQ (auto-gptq required)
|
||||
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
|
||||
build_ctx: The build context to use for building BentoLLM. By default, it sets to current directory.
|
||||
enable_features: Additional OpenLLM features to be included with this BentoLLM.
|
||||
dockerfile_template: The dockerfile template to use for building BentoLLM. See https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template.
|
||||
overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``.
|
||||
push: Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.
|
||||
containerize: Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.
|
||||
Note that 'containerize' and 'push' are mutually exclusive
|
||||
container_registry: Container registry to choose the base OpenLLM container image to build from. Default to ECR.
|
||||
serialisation: Serialisation for saving models. Default to 'safetensors', which is equivalent to `safe_serialization=True`
|
||||
additional_args: Additional arguments to pass to ``openllm build``.
|
||||
bento_store: Optional BentoStore for saving this BentoLLM. Default to the default BentoML local store.
|
||||
|
||||
Returns:
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
args: list[str] = [
|
||||
sys.executable,
|
||||
'-m',
|
||||
'openllm',
|
||||
'build',
|
||||
model_id,
|
||||
'--machine',
|
||||
'--quiet',
|
||||
'--serialisation',
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
]
|
||||
if quantize:
|
||||
args.extend(['--quantize', quantize])
|
||||
if containerize and push:
|
||||
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push:
|
||||
args.extend(['--push'])
|
||||
if containerize:
|
||||
args.extend(['--containerize'])
|
||||
if build_ctx:
|
||||
args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features:
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
args.extend(['--model-version', model_version])
|
||||
if bento_version:
|
||||
args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template:
|
||||
args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if force_push:
|
||||
args.append('--force-push')
|
||||
|
||||
current_disable_warning = get_disable_warnings()
|
||||
os.environ[WARNING_ENV_VAR] = str(True)
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
|
||||
if e.stderr:
|
||||
raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
|
||||
try:
|
||||
result = orjson.loads(matched.group(1))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
) from e
|
||||
return bentoml.get(result['tag'], _bento_store=bento_store)
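

# Hedged usage sketch (an assumption-heavy example, not executed on import): building a BentoLLM
# for an example model id and inspecting the resulting tag.
def _example_build() -> None:
  import openllm

  bento = openllm.build('facebook/opt-1.3b', serialisation='safetensors')
  print('Built bento:', bento.tag)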
|
||||
|
||||
|
||||
def _import_model(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: t.Sequence[str] | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Import a LLM into local store.
|
||||
|
||||
> [!NOTE]
|
||||
> If ``quantize`` is passed, the model weights will be saved as quantized weights. You should
|
||||
> only use this option if you want the weight to be quantized by default. Note that OpenLLM also
|
||||
> support on-demand quantisation during initial startup.
|
||||
|
||||
``openllm.import_model`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``.
|
||||
|
||||
> [!NOTE]
|
||||
> ``openllm.start`` will automatically invoke ``openllm.import_model`` under the hood.
|
||||
|
||||
Args:
|
||||
model_id: required model id for this given LLM
|
||||
model_version: Optional model version for this given LLM
|
||||
backend: The backend to use for this LLM. By default, this is set to ``pt``.
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
- int4: Quantize the model with 4bit (bitsandbytes required)
|
||||
- gptq: Quantize the model with GPTQ (auto-gptq required)
|
||||
serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. Default behaviour is similar to ``safe_serialization=False``.
|
||||
additional_args: Additional arguments to pass to ``openllm import``.
|
||||
|
||||
Returns:
|
||||
``bentoml.Model``: The BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from .entrypoint import import_command
|
||||
|
||||
args = [model_id, '--quiet']
|
||||
if backend is not None:
|
||||
args.extend(['--backend', backend])
|
||||
if model_version is not None:
|
||||
args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None:
|
||||
args.extend(['--quantize', quantize])
|
||||
if serialisation is not None:
|
||||
args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None:
|
||||
args.extend(additional_args)
|
||||
return import_command.main(args=args, standalone_mode=False)
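

# Hedged usage sketch (example model id; not executed on import): importing weights into the
# local BentoML store before serving.
def _example_import_model() -> None:
  import openllm

  openllm.import_model('facebook/opt-1.3b', serialisation='safetensors')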
|
||||
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
"""List all available models within the local store."""
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--quiet'], standalone_mode=False)
|
||||
|
||||
|
||||
start, build, import_model, list_models = (
|
||||
codegen.gen_sdk(_start),
|
||||
codegen.gen_sdk(_build),
|
||||
codegen.gen_sdk(_import_model),
|
||||
codegen.gen_sdk(_list_models),
|
||||
)
|
||||
__all__ = ['build', 'import_model', 'list_models', 'start']
|
||||
5
openllm-python/src/openllm_cli/entrypoint.py
Normal file
5
openllm-python/src/openllm_cli/entrypoint.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# mainly for backward compatible entrypoint.
|
||||
from _openllm_tiny._entrypoint import cli as cli
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
@@ -1,16 +0,0 @@
|
||||
"""OpenLLM CLI Extension.
|
||||
|
||||
The following directory contains all possible extensions for OpenLLM CLI
|
||||
To add a new extension, simply name the file `<name_ext>.py` and define
|
||||
a ``click.command()`` with the following format:
|
||||
|
||||
```python
|
||||
import click
|
||||
|
||||
@click.command(<name_ext>)
|
||||
...
|
||||
def cli(...):  # <- important: the command must always be named `cli` so that the extension resolver knows how to import this extension.
|
||||
```
|
||||
|
||||
NOTE: Make sure to keep this file otherwise empty so that it won't mess with the import order.
|
||||
"""
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import psutil
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import bento_complete_envvar, machine_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(
|
||||
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
|
||||
ctx.fail(
|
||||
f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness."
|
||||
)
|
||||
if machine:
|
||||
return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS:
|
||||
subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
else:
|
||||
subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
@@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.bento.bento import BentoInfo
|
||||
from bentoml._internal.bento.build_config import DockerOptions
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.container.generate import generate_containerfile
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import converter
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command(
|
||||
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
|
||||
)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
  # The logic below is similar to bentoml._internal.container.construct_containerfile
|
||||
with open(bentomodel.path_of('bento.yaml'), 'r') as f:
|
||||
options = BentoInfo.from_yaml_file(f)
|
||||
# NOTE: dockerfile_template is already included in the
|
||||
# Dockerfile inside bento, and it is not relevant to
|
||||
# construct_containerfile. Hence it is safe to set it to None here.
|
||||
# See https://github.com/bentoml/BentoML/issues/3399.
|
||||
docker_attrs = converter.unstructure(options.docker)
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None:
|
||||
docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(
|
||||
docker=DockerOptions(**docker_attrs),
|
||||
build_ctx=bentomodel.path,
|
||||
conda=options.conda,
|
||||
bento_fs=bentomodel._fs,
|
||||
enable_buildkit=True,
|
||||
add_header=True,
|
||||
)
|
||||
termui.echo(doc, fg='white')
|
||||
return bentomodel.path
|
||||
@@ -1,179 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml_cli.utils import opt_callback
|
||||
|
||||
import openllm
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
_input_variables: t.Sequence[str] = attr.field(init=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._input_variables = default_formatter.extract_template_variables(self.template)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> PromptTemplate:
|
||||
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
|
||||
o = attr.evolve(self, template=self.template.format(**prompt_variables))
|
||||
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
|
||||
return o
|
||||
|
||||
def format(self, **attrs: t.Any) -> str:
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
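

# Hedged usage sketch of the template helper above (the template string is made up):
_example_template = PromptTemplate('{system_message}\n\nUser: {instruction}\nAssistant:')
# _example_template.with_options(system_message='You are a helpful assistant').format(instruction='Hi!')
# -> 'You are a helpful assistant\n\nUser: Hi!\nAssistant:'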
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_id', shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--prompt-template-file', type=click.File(), default=None)
|
||||
@click.option('--chat-template-file', type=click.File(), default=None)
|
||||
@click.option('--system-message', type=str, default=None)
|
||||
@click.option(
|
||||
'--add-generation-prompt/--no-add-generation-prompt',
|
||||
default=False,
|
||||
help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This only applicable if model-id is a HF model_id',
|
||||
)
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
chat_template_file: t.IO[t.Any] | None,
|
||||
system_message: str | None,
|
||||
add_generation_prompt: bool,
|
||||
_memoized: dict[str, t.Any],
|
||||
**_: t.Any,
|
||||
) -> str | None:
|
||||
"""Helpers for generating prompts.
|
||||
|
||||
\b
|
||||
It accepts remote HF model_ids as well as model names passed to `openllm start`.

If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.

```bash
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
```

If you need to change the prompt template, you can provide a file containing the Jinja2 template through `--chat-template-file`.
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
|
||||
```
|
||||
|
||||
\b
|
||||
|
||||
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
|
||||
Note that this is mainly for utilities, as OpenLLM won't use these prompts to format for you.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt mistral "Hello there"
|
||||
"""
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
|
||||
if prompt_template_file and chat_template_file:
|
||||
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
|
||||
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
|
||||
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
|
||||
)
|
||||
if model_id in acceptable:
|
||||
logger.warning(
|
||||
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
|
||||
)
|
||||
config = openllm.AutoConfig.for_model(model_id)
|
||||
template = prompt_template_file.read() if prompt_template_file is not None else config.template
|
||||
system_message = system_message or config.system_message
|
||||
|
||||
try:
|
||||
formatted = (
|
||||
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
ctx.fail(str(err))
|
||||
else:
|
||||
import transformers
|
||||
|
||||
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
if chat_template_file is not None:
|
||||
chat_template_file = chat_template_file.read()
|
||||
if system_message is None:
|
||||
      logger.warning('system-message is not provided, inferring the default from the model architecture.\n')
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
system_message = (
|
||||
openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
)
|
||||
.model_construct_env()
|
||||
.system_message
|
||||
)
|
||||
break
|
||||
else:
|
||||
ctx.fail(
|
||||
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
|
||||
)
|
||||
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
|
||||
)
|
||||
|
||||
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
@@ -1,34 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm_cli import termui
|
||||
|
||||
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
mapping = {
|
||||
k: [
|
||||
{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [
|
||||
{'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))}
|
||||
for m in (bentoml.models.get(_.tag) for _ in b.info.models)
|
||||
],
|
||||
}
|
||||
for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'}))
|
||||
if b.info.labels['start_name'] == k
|
||||
]
|
||||
for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
@@ -1,49 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar, model_name_argument
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
"""List available models in local store to be used with OpenLLM."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in bentoml.models.list()
|
||||
if 'framework' in i.info.labels
|
||||
and i.info.labels['framework'] == 'openllm'
|
||||
and 'model_name' in i.info.labels
|
||||
and i.info.labels['model_name'] == k
|
||||
]
|
||||
for k in models
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in v
|
||||
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
|
||||
]
|
||||
for k, v in ids_in_local_store.items()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {
|
||||
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
|
||||
for k, val in ids_in_local_store.items()
|
||||
}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
@@ -1,108 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import importlib.machinery
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import jupytext
|
||||
import nbformat
|
||||
import yaml
|
||||
|
||||
from openllm_cli import playground, termui
|
||||
from openllm_core.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
if not all('description' in k for k in content.values()):
|
||||
raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option(
|
||||
'--port',
|
||||
envvar='JUPYTER_PORT',
|
||||
show_envvar=True,
|
||||
show_default=True,
|
||||
default=8888,
|
||||
help='Default port for Jupyter server',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
|
||||
A collection of notebooks to explore the capabilities of OpenLLM.
This includes notebooks for fine-tuning, inference, and more.

All of the scripts available in the playground can also be run directly as Python scripts:
|
||||
For example:
|
||||
|
||||
\b
|
||||
```bash
|
||||
python -m openllm.playground.falcon_tuned --help
|
||||
```
|
||||
|
||||
\b
|
||||
> [!NOTE]
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError(
|
||||
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
|
||||
)
|
||||
metadata = load_notebook_metadata()
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
_temp_dir = True
|
||||
output_dir = tempfile.mkdtemp(prefix='openllm-playground-')
|
||||
else:
|
||||
os.makedirs(os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))), exist_ok=True)
|
||||
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug(
|
||||
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
|
||||
)
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
|
||||
continue
|
||||
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
|
||||
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]['description'])
|
||||
f = jupytext.read(os.path.join(module.module_finder.path, module.name + '.py'))
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output([
|
||||
sys.executable,
|
||||
'-m',
|
||||
'jupyter',
|
||||
'notebook',
|
||||
'--notebook-dir',
|
||||
output_dir,
|
||||
'--port',
|
||||
str(port),
|
||||
'--no-browser',
|
||||
'--debug',
|
||||
])
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
except KeyboardInterrupt:
|
||||
termui.echo('\nShutting down Jupyter server...', fg='yellow')
|
||||
if _temp_dir:
|
||||
termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
ctx.exit(0)
|
||||
@@ -1,14 +0,0 @@
|
||||
This folder represents a playground and the source of truth for feature
development around OpenLLM.
|
||||
|
||||
```bash
|
||||
openllm playground
|
||||
```
|
||||
|
||||
## Usage for developing this module
|
||||
|
||||
Write a python script that can be used by `jupytext` to convert it to a jupyter
|
||||
notebook via `openllm playground`
|
||||
|
||||
Make sure to add the new file and its documentation into
|
||||
[`_meta.yml`](./_meta.yml).
|
||||
@@ -1,36 +0,0 @@
|
||||
features:
|
||||
description: |
|
||||
## General introduction to OpenLLM.
|
||||
|
||||
This script will demo a few features from OpenLLM:
|
||||
|
||||
- Usage of Auto class abstraction and run prediction with `generate`
|
||||
- Ability to send per-requests parameters
|
||||
- Runner integration with BentoML
|
||||
opt_tuned:
|
||||
description: |
|
||||
## Fine tuning OPT
|
||||
|
||||
This script demonstrates how one can easily fine-tune OPT
|
||||
with [LoRa](https://arxiv.org/abs/2106.09685) and in int8 with bitsandbytes.
|
||||
|
||||
It is based on one of the Peft examples fine tuning script.
|
||||
It requires at least one GPU to be available, so make sure to have it.
|
||||
falcon_tuned:
|
||||
description: |
|
||||
## Fine tuning Falcon
|
||||
|
||||
This script demonstrates how one can fine-tune Falcon using [QLoRa](https://arxiv.org/pdf/2305.14314.pdf),
|
||||
[trl](https://github.com/lvwerra/trl).
|
||||
|
||||
It is trained using OpenAssistant's Guanaco [dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco)
|
||||
|
||||
It requires at least one GPU to be available, so make sure to have it.
|
||||
llama2_qlora:
|
||||
description: |
|
||||
## Fine tuning LLaMA 2
This script demonstrates how one can fine-tune LLaMA 2 using LoRA with [trl](https://github.com/lvwerra/trl)
|
||||
|
||||
It is trained using the [Dolly datasets](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
|
||||
|
||||
It requires at least one GPU to be available.
|
||||
@@ -1,95 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# On notebook, make sure to install the following
|
||||
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
|
||||
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
DEFAULT_MODEL_ID = 'ybelkada/falcon-7b-sharded-bf16'
|
||||
DATASET_NAME = 'timdettmers/openassistant-guanaco'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
optim: str = dataclasses.field(default='paged_adamw_32bit')
|
||||
save_steps: int = dataclasses.field(default=10)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=500)
|
||||
logging_steps: int = dataclasses.field(default=10)
|
||||
learning_rate: float = dataclasses.field(default=2e-4)
|
||||
max_grad_norm: float = dataclasses.field(default=0.3)
|
||||
warmup_ratio: float = dataclasses.field(default=0.03)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default='constant')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'falcon'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(
|
||||
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
model, tokenizer = llm.prepare(
|
||||
'lora',
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias='none',
|
||||
target_modules=['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
|
||||
)
|
||||
model.config.use_cache = False
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split='train')
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
)
|
||||
|
||||
# upcast layernorm in float32 for more stable training
|
||||
for name, module in trainer.model.named_modules():
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
@@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_NEW_TOKENS = 384
|
||||
|
||||
Q = 'Answer the following question, step by step:\n{q}\nA:'
|
||||
question = 'What is the meaning of life?'
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('question', default=question)
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
args = parser.parse_args(args=[question])
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any]('facebook/opt-2.7b')
|
||||
prompt = Q.format(q=args.question)
|
||||
|
||||
  logger.info('%s %s %s', '-' * 50, "Running with 'generate()'", '-' * 50)
  res = await llm.generate(prompt)
  logger.info('%s %s %s', '=' * 10, 'Response:', res)

  logger.info('%s %s %s', '-' * 50, "Running with 'generate()' with per-request arguments", '-' * 50)
  res = await llm.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
  logger.info('%s %s %s', '=' * 10, 'Response:', res)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _mp_fn(index: t.Any): # type: ignore
|
||||
# For xla_spawn (TPUs)
|
||||
asyncio.run(main())
|
||||
@@ -1,236 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# On notebook, make sure to install the following
|
||||
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
|
||||
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from random import randint, randrange
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
|
||||
def find_all_linear_names(model):
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bnb.nn.Linear4bit):
|
||||
names = name.split('.')
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if 'lm_head' in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove('lm_head')
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
# Change this to the local converted path if you don't have access to the meta-llama model
|
||||
DEFAULT_MODEL_ID = 'meta-llama/Llama-2-7b-hf'
|
||||
# change this to 'main' if you want to use the latest llama
|
||||
DEFAULT_MODEL_VERSION = '335a02887eb6684d487240bbc28b5699298c3135'
|
||||
DATASET_NAME = 'databricks/databricks-dolly-15k'
|
||||
|
||||
|
||||
def format_dolly(sample):
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample['context']) > 0 else None
|
||||
response = f"### Answer\n{sample['response']}"
|
||||
# join all the parts together
|
||||
prompt = '\n\n'.join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
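

# Hedged illustration with a made-up sample: the context block is dropped when it is empty.
_example_dolly_sample = {'instruction': 'Say hi', 'context': '', 'response': 'Hi!'}
# format_dolly(_example_dolly_sample) -> '### Instruction\nSay hi\n\n### Answer\nHi!'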
|
||||
|
||||
|
||||
# template dataset to add prompt to each sample
|
||||
def template_dataset(sample, tokenizer):
|
||||
sample['text'] = f'{format_dolly(sample)}{tokenizer.eos_token}'
|
||||
return sample
|
||||
|
||||
|
||||
# empty list to save remainder from batches to use in next batch
|
||||
remainder = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
|
||||
|
||||
|
||||
def chunk(sample, chunk_length=2048):
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
global remainder
|
||||
# Concatenate all texts and add remainder from previous batch
|
||||
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
|
||||
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
|
||||
# get total number of tokens for batch
|
||||
batch_total_length = len(concatenated_examples[next(iter(sample.keys()))])
|
||||
|
||||
# get max number of chunks for batch
|
||||
if batch_total_length >= chunk_length:
|
||||
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
result['labels'] = result['input_ids'].copy()
|
||||
return result
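

# Hedged illustration of the packing above (standalone toy, no remainder carry-over): token ids
# are grouped into fixed-length blocks and the tail is left for the next batch.
def _sketch_pack(ids: list, chunk_length: int = 4) -> list:
  usable = (len(ids) // chunk_length) * chunk_length
  return [ids[i : i + chunk_length] for i in range(0, usable, chunk_length)]


# _sketch_pack(list(range(1, 11))) -> [[1, 2, 3, 4], [5, 6, 7, 8]]  (9 and 10 carry over)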
|
||||
|
||||
|
||||
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split='train')
|
||||
|
||||
print(f'dataset size: {len(dataset)}')
|
||||
print(dataset[randrange(len(dataset))])
|
||||
|
||||
# apply prompt template per sample
|
||||
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
|
||||
# print random sample
|
||||
print('Sample from dolly-v2 ds:', dataset[randint(0, len(dataset))]['text'])
|
||||
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(
|
||||
lambda sample: tokenizer(sample['text']), batched=True, remove_columns=list(dataset.features)
|
||||
).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
|
||||
# Print total number of samples
|
||||
print(f'Total number of samples: {len(lm_dataset)}')
|
||||
return lm_dataset
|
||||
|
||||
|
||||
def prepare_for_int4_training(
|
||||
model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
llm = openllm.LLM(
|
||||
model_id,
|
||||
revision=model_version,
|
||||
quantize='int4',
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map='auto',
|
||||
)
|
||||
print('Model summary:', llm.model)
|
||||
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f'Found {len(modules)} modules to quantize: {modules}')
|
||||
|
||||
model, tokenizer = llm.prepare('lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
|
||||
  # pre-process the model by upcasting the layer norms to float32 for more stable training
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
if bf16:
|
||||
module = module.to(torch.bfloat16)
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
if 'lm_head' in name or 'embed_tokens' in name:
|
||||
if hasattr(module, 'weight'):
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
gradient_checkpointing: bool = dataclasses.field(default=True)
|
||||
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
|
||||
learning_rate: float = dataclasses.field(default=5e-5)
|
||||
num_train_epochs: int = dataclasses.field(default=3)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
report_to: str = dataclasses.field(default='none')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'llama'))
|
||||
save_strategy: str = dataclasses.field(default='no')
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
  model_args, training_args = ModelArguments(), TrainingArguments()
|
||||
else:
  parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
  if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
  else:
    model_args, training_args = t.cast(
      t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
    )

# import the model ahead of time
openllm.import_model(model_id=model_args.model_id, model_version=model_args.model_version)


def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
  import peft

  transformers.set_seed(model_args.seed)

  model, tokenizer = prepare_for_int4_training(
    model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16
  )
  datasets = prepare_datasets(tokenizer)

  trainer = transformers.Trainer(
    model=model,
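    # dataclasses.replace works here because transformers.TrainingArguments is itself a dataclass,
    # so the lightweight TrainingArguments overrides defined above are merged into the full HF argument set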
    args=dataclasses.replace(
      transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
    ),
    train_dataset=datasets,
    data_collator=transformers.default_data_collator,
  )

  trainer.train()

  if model_args.merge_weights:
    # note that this requires a larger GPU, since we will load the whole model into memory

    # save the int4 adapter, then merge it with the base model weights and save the result
    trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)

    # gc mem
    del model, trainer
    torch.cuda.empty_cache()

    model = peft.AutoPeftModelForCausalLM.from_pretrained(
      training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
    )
    # merge lora with base weights and save
    model = model.merge_and_unload()
    model.save_pretrained(
      os.path.join(os.getcwd(), 'outputs', 'merged_llama_lora'), safe_serialization=True, max_shard_size='2GB'
    )
  else:
    trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))


train_loop(model_args, training_args)
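# Usage sketch (the script filename and model id below are illustrative assumptions, not values from
# this repository). Because HfArgumentParser is built from the two dataclasses above, outside a
# notebook the script can be driven from the CLI or from a single JSON file:
#   python finetune_llama.py --model_id meta-llama/Llama-2-7b-hf --merge_weights True
#   python finetune_llama.py args.json   # parsed via parser.parse_json_file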
@@ -1,81 +0,0 @@
from __future__ import annotations
import dataclasses
import logging
import os
import sys
import typing as t

import transformers

import openllm

# Make sure to have at least one GPU to run this script

openllm.utils.configure_logging()

logger = logging.getLogger(__name__)

# In a notebook, make sure to install the following first:
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git

from datasets import load_dataset

if t.TYPE_CHECKING:
  from peft import PeftModel

DEFAULT_MODEL_ID = 'facebook/opt-6.7b'


def load_trainer(
  model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments
):
  return transformers.Trainer(
    model=model,
    train_dataset=dataset_dict['train'],
    args=dataclasses.replace(
      transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
  )


@dataclasses.dataclass
class TrainingArguments:
  per_device_train_batch_size: int = dataclasses.field(default=4)
  gradient_accumulation_steps: int = dataclasses.field(default=4)
  warmup_steps: int = dataclasses.field(default=10)
  max_steps: int = dataclasses.field(default=50)
  learning_rate: float = dataclasses.field(default=3e-4)
  fp16: bool = dataclasses.field(default=True)
  logging_steps: int = dataclasses.field(default=1)
  output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'opt'))


@dataclasses.dataclass
class ModelArguments:
  model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)


parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
  # If we pass only one argument to the script and it's the path to a json file,
  # let's parse it to get our arguments.
  model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
  model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())

llm = openllm.LLM(model_args.model_id, quantize='int8')
model, tokenizer = llm.prepare(
  'lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
)

# fine-tune on the english_quotes dataset
data = load_dataset('Abirate/english_quotes')
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)

trainer = load_trainer(model, tokenizer, data, training_args)
model.config.use_cache = False  # silence cache warnings during training; re-enable for inference later

trainer.train()

trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
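# Follow-up sketch (assumption: peft's AutoPeftModelForCausalLM, as used in the llama example above,
# is available): in a fresh process the saved adapter can be reloaded together with its base model:
#   from peft import AutoPeftModelForCausalLM
#   reloaded = AutoPeftModelForCausalLM.from_pretrained(os.path.join(training_args.output_dir, 'lora'))
#   reloaded.config.use_cache = True  # re-enable the cache that was turned off for training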
@@ -1,90 +1,9 @@
from __future__ import annotations
import enum
import functools
import logging
import os
import typing as t

import click
import inflection
import orjson

from openllm_core._typing_compat import DictStrAny, TypedDict
from openllm_core.utils import get_debug_mode

logger = logging.getLogger('openllm')
from _openllm_tiny import _termui


class Level(enum.IntEnum):
  NOTSET = logging.DEBUG
  DEBUG = logging.DEBUG
  INFO = logging.INFO
  WARNING = logging.WARNING
  ERROR = logging.ERROR
  CRITICAL = logging.CRITICAL

  @property
  def color(self) -> str | None:
    return {
      Level.NOTSET: None,
      Level.DEBUG: 'cyan',
      Level.INFO: 'green',
      Level.WARNING: 'yellow',
      Level.ERROR: 'red',
      Level.CRITICAL: 'red',
    }[self]

  @classmethod
  def from_logging_level(cls, level: int) -> Level:
    return {
      logging.DEBUG: Level.DEBUG,
      logging.INFO: Level.INFO,
      logging.WARNING: Level.WARNING,
      logging.ERROR: Level.ERROR,
      logging.CRITICAL: Level.CRITICAL,
    }[level]
def __dir__():
  return dir(_termui)


class JsonLog(TypedDict):
  log_level: Level
  content: str


def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
  if get_debug_mode():
    echo(content, fg=fg)
  else:
    echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)


warning = functools.partial(log, level=Level.WARNING)
error = functools.partial(log, level=Level.ERROR)
critical = functools.partial(log, level=Level.CRITICAL)
debug = functools.partial(log, level=Level.DEBUG)
info = functools.partial(log, level=Level.INFO)
notset = functools.partial(log, level=Level.NOTSET)


def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
  if json:
    text = orjson.loads(text)
    if 'content' in text and 'log_level' in text:
      content = text['content']
      fg = Level.from_logging_level(text['log_level']).color
    else:
      content = orjson.dumps(text).decode()
      fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
  else:
    content = t.cast(str, text)
  attrs['fg'] = fg

  (click.echo if not _with_style else click.secho)(content, **attrs)


COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {
  'help_option_names': ['-h', '--help'],
  'max_content_width': COLUMNS,
  'token_normalize_func': inflection.underscore,
}
__all__ = ['COLUMNS', 'CONTEXT_SETTINGS', 'Level', 'critical', 'debug', 'echo', 'error', 'info', 'log', 'warning']
def __getattr__(name):
  return getattr(_termui, name)
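# Forwarding sketch: together with __dir__ above, this module-level __getattr__ (PEP 562) delegates
# attribute lookups on this module to _openllm_tiny._termui, so existing call sites keep working, e.g.
#   termui.info('hello')      # resolved via __getattr__ -> getattr(_termui, 'info')
#   'echo' in dir(termui)     # True, via __dir__ -> dir(_termui)
# (here `termui` stands for this shim module imported under its existing path.)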